├── IOU.csv
├── LICENSE
├── README.md
├── checkpoint
│   └── model-01.pt
├── dataset.py
├── images
│   ├── loss.png
│   ├── model.png
│   ├── poster.png
│   └── prediction_results01.png
├── logs.txt
├── model.py
├── plot_loss.py
├── prediction_results.csv
├── train.py
└── utils.py

/IOU.csv:
--------------------------------------------------------------------------------
img_id,iou
ef3ef194e5657fda708ecbd3eb6530286ed2ba23c88efb9f1715298975c73548,0.8198221929535726
3d0ca3498d97edebd28dbc7035eced40baa4af199af09cbb7251792accaa69fe,0.8859934853420195
ff599c7301daa1f783924ac8cbe3ce7b42878f15a39c2d19659189951f540f48,0.9484287262338685
7b38c9173ebe69b4c6ba7e703c0c27f39305d9b2910f46405993d2ea7a963b80,0.0
9bb6e39d5f4415bc7554842ee5d1280403a602f2ba56122b87f453a62d37c06e,0.65300736904999
150b0ffa318c87b31d78af0e87d60390dbcd84b5f228a8c1fb3225cbe5df3e3f,0.761706533682698
52a4ac5a875be7a6c886035d54fb63f5f397dc43508c4831898f6b2f8debc7f3,0.5530534351145038
4ee5850b63549794eb3ecd3d5f5673164ac16936e36ecc3700da886e3b616149,0.8767139087933794
2abc40c118bc7303592c8bb95a80361e27560854b8971ab34dcf91966575b1f2,0.8145248380129589
c96109cbebcf206f20035cbde414e43872074eee8d839ba214feed9cd36277a1,0.9277179487179488
1609b1b8480ee52652a644403b3f7d5511410a016750aa3b9a4c8ddb3e893e8e,0.9330390252516334
f35ab34528e3e2d2589d24cbffc0e10024dfc474a68585d0b5feb7b05aa0067f,0.8724584103512015
3441821ebea04face181c9e2f4d0d09727c764827ac51b9e7fbadbebabeab225,0.9211937694064513
e4ae1ceddb279bac30273ca7ac480025ce2e7287328f5272234b5bbca6d13135,0.8680363492690636
708eb41a3fc8f2b6cd1f529cdf38dc4ad5d5f00ad30bdcba92884f37ff78d614,0.7873008849557522
63d981a107091e1e3059102ce08870744dde173afe324bc2274c17d42f661778,0.9420451172454735
ce37f6dd0615d45e66e41a8f2ed6fbc0bbe3103a290394ad474207507710eacc,0.7662665824384081
f487cc82271cf84b4414552aa8b0a9d82d902451ebe8e8bc639d4121c1672ff7,0.793134025842412
f4faa3a409014db1865074c5f66a0255f71ae3faba03265da0b3b91f68e8a8f0,0.8880597014925373
c44ed955eb2e5c8d820b01477e122b32eff6dd475343e11229c33d8af3473b22,0.7679224065223503
6fc83b33896f58a4a067d8fdcf51f15d4ae9be05d8c3815d23336f1f2a8c45a1,0.9584692969445269
94a5a37c3b1153d5c5aef2eca53c960b9f21f2ef1758209d7ec502ec324b03a3,0.870420017873101
3bfd6bb152310f93daa6f4e1867c10572946e874b3a30c9ba8e0fcdeb590300b,0.7842712131664531
4829177d0b36abdd92c4ef0c7834cbc49f95232076bdd7e828f1f7cbb5ed80ec,0.8766928011404134
8a65e41c630d85c0004ce1772ff66fbc87aca34cb165f695255b39343fcfc832,0.7329286798179059
8ecdb93582b2d5270457b36651b62776256ade3aaa2d7432ae65c14f07432d49,0.8785046728971962
c901794d1a421d52e5734500c0a2a8ca84651fb93b19cec2f411855e70cae339,0.9178572132342798
a7a581e6760df4701941670e73d72533e3b0fbd7563488ad92772b41f7709710,0.7832291680034907
aa83f5b4fca02ae43a6b9456ab42707b0beabc6e7c5c4e66c0d2572fb80f3615,0.8782816229116945
2ab91a4408860ae8339689ed9f87aa9359de1bdd4ca5c2eab7fff7724dbd6707,0.7707200348053078
b7a86f4968071e0f963fa87ef314fdd1b6c73a66355431cc53a37e193ba6be9b,0.8940998487140696
8b77284d6f37ab3fc826139ebadaec3b9d81c552fe525c3547bbbd6c65ac0d83,0.8592342342342343
317832f90f02c5e916b2ac0f3bcb8da9928d8e400b747b2c68e544e56adacf6b,0.8743006340917568
0e21d7b3eea8cdbbed60d51d72f4f8c1974c5d76a8a3893a7d5835c85284132e,0.7514035481697732
7978812d0e2e034ee1f9c141f019705582fcaa290e4a01c6c75a62753285cb23,0.9389078900813775
cf26c41245febfe67c2a1682cc4ee8752ee40ae3e49610314f45923b8bf5b08a,0.8496415770609319
f7eaaf420b5204c4a42577428b7cd897a53ef07b759ccbba3ed30a3548ca5605,0.917940876656473
6af82abb29539000be4696884fc822d3cafcb2105906dc7582c92dccad8948c5,0.8204493918779633
14cc1424c59808274e123db51292e9dbb5b037ef3e7c767a8c45c9ac733b91bf,0.9411147540983607
212b858a66f0d23768b8e3e1357704fc2f4cf4bbe7eed8cd59b5d01031d553e6,0.7929607027636335
a891bbc89143bca7a717386144eb061ec2d599cba24681389bcb3a2fedb8ff8c,0.8038313629084084
45cc00f2ef95da6698bf590663e319d7c0ed4fb99d42dd3cf4060887da74fb81,0.9522213654427902
0b0d577159f0d6c266f360f7b8dfde46e16fa665138bf577ec3c6f9c70c0cd1e,0.4772065955383123
d2815f2f616d92be35c7e8dcfe592deec88516aef9ffc9b21257f52b7d6d0354,0.7602674307545367
29780b28e6a75fac7b96f164a1580666513199794f1b19a5df8587fe0cb59b67,0.845674740484429
d2ce593bddf9998ce3b76328c0151d0ba4b644c293aca7f6254e521c448b305f,0.8646884272997033
76faaed50ed6ea6814ac36199964b86fb09ba7f41a6f213bceaa80d625adc2e1,0.7869195454203357
cbca32daaae36a872a11da4eaff65d1068ff3f154eedc9d3fc0c214a4e5d32bd,0.7349389747762408
b1eb0123fe2d8c825694b193efb7b923d95effac9558ee4eaf3116374c2c94fe,0.7002412545235223
20b20ab049372d184c705acebe7af026d3580f5fd5a72ed796e3622e1685af2f,0.29885991280093155
aa47f0b303b1d525b52452ae3a8553b2d61d719a28aee547e2ef1fc6730a078f,0.7732030704815074
958114e5f37d5e1420b410bd716753b3e874b175f2b6958ebf1ec2bdf776e41f,0.797525899374918
a02ec007ae8feddb758078b1dfb8010c26886fd3c8babdc308ead8b4a63acbdb,0.7960691873984647
4185b9369fc8bdcc7e7c68f2129b9a7442237cd0f836a4b6d13ef64bf0ef572a,0.8800443108086723
813f41ef376c3cbcc9d6e2ce6a51c2ee068226d1c1b13404eb238dcfdd447c97,0.8637226970560304
fe80a2cf3c93dafad8c364fdd1646b0ba4db056cdb7bdb81474f957064812bba,0.9057526318415406
1d4a5e729bb96b08370789cad0791f6e52ce0ffe1fcc97a04046420b43c851dd,0.8260869565217391
6bc8cda54f5b66a2a27d962ac219f8075bf7cc43b87ba0c9e776404370429e80,0.9266300920670955
1d02c4b5921e916b9ddfb2f741fd6cf8d0e571ad51eb20e021c826b5fb87350e,0.8161112809538368
5afb7932e9c7328f4fb1d7a8166a3699d6cdc5192b93758a75e9956f1513c5a3,0.7585681398138842
57b49733c5a3c268b013553635a826e6a1b10e699bbd19c3b842375fe0adf344,0.9367715374558402
5d58600efa0c2667ec85595bf456a54e2bd6e6e9a5c0dff42d807bc9fe2b822e,0.1484545572563228
564fa390d9a9c26f986bf860d9091cbd84244bc1c8e3c9369f2f2e5b5fd99b92,0.8392229417206291
f29fd9c52e04403cd2c7d43b6fe2479292e53b2f61969d25256d2d2aca7c6a81,0.2815791080313807
bb8ebf465c968a5f6f715de5d9e2e664afd1bcaa533e0e3352ecea1cc5b6fb0d,0.6940632972488451
514ccfc78cb55988a238d3ac9dc83460aa88382c95d56bcc0559962d9fe481ef,0.8953597742592883
f952cc65376009cfad8249e53b9b2c0daaa3553e897096337d143c625c2df886,0.6654926821632621
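
Note: the iou column above is the intersection-over-union between a predicted
mask and its ground-truth mask. A minimal sketch of the metric for a pair of
binary masks (for reference only; not necessarily the exact evaluation code
this repository used):

    import numpy as np

    def iou(pred, target):
        """|A & B| / |A | B|, defined as 0.0 when both masks are empty."""
        pred, target = pred.astype(bool), target.astype(bool)
        union = np.logical_or(pred, target).sum()
        return float(np.logical_and(pred, target).sum() / union) if union else 0.0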
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Liming Wu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# UNet-pytorch

## Overview
This is code for the Kaggle 2018 Data Science Bowl nuclei segmentation challenge (https://www.kaggle.com/c/data-science-bowl-2018). We use a U-Net to perform the segmentation task.

## Dependencies

* numpy
* scipy
* tqdm
* pillow
* scikit-image
* pytorch
* pandas


## Usage

1. Download the dataset from Kaggle (https://www.kaggle.com/c/data-science-bowl-2018/data).

2. Create two folders named `combined` and `testing_data`. Run the script `utils.py` to prepare the training and testing images; the prepared images will be placed in the `combined` and `testing_data` folders.

3. In the `Option` class in `utils.py`, set `is_train = True` and adjust the three directory paths and the other parameters.

4. Run the script `train.py`. The model will be saved under the `checkpoint` folder.

5. To make predictions on the testing data, set `is_train = False` in `utils.py` and run `train.py` again. The prediction masks will be saved to the folder specified in the `Option` class in `utils.py`, and the run-length-encoding csv file will be saved in the current folder (see the sketch below for the encoding format).
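
The csv in step 5 uses the competition's run-length-encoding submission format: mask pixels are numbered from 1, top to bottom and then column by column (column-major order), and each run of foreground pixels is written as a `start length` pair. A minimal sketch of such an encoder (for reference; the encoder actually used by `train.py` may differ):

```python
import numpy as np

def rle_encode(mask):
    """Kaggle-style RLE for a 2-D binary mask: 1-indexed, column-major 'start length' pairs."""
    pixels = mask.T.flatten().astype(bool)             # column-major pixel order
    padded = np.concatenate([[False], pixels, [False]])
    runs = np.where(padded[1:] != padded[:-1])[0] + 1  # alternating run starts/ends (1-indexed)
    runs[1::2] -= runs[::2]                            # turn end positions into run lengths
    return ' '.join(str(x) for x in runs)
```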
27 | """ 28 | 29 | def __init__(self, output_size, train=True): 30 | assert isinstance(output_size, (int, tuple)) 31 | self.output_size = output_size 32 | self.train = train 33 | 34 | def __call__(self, sample): 35 | if self.train: 36 | image, mask, img_id, height, width = sample['image'], sample['mask'], sample['img_id'], sample['height'],sample['width'] 37 | 38 | if isinstance(self.output_size, int): 39 | new_h = new_w = self.output_size 40 | else: 41 | new_h, new_w = self.output_size 42 | 43 | new_h, new_w = int(new_h), int(new_w) 44 | 45 | # resize the image, 46 | # preserve_range means not normalize the image when resize 47 | img = transform.resize(image, (new_h, new_w), preserve_range=True, mode='constant') 48 | mask = transform.resize(mask, (new_h, new_w), preserve_range=True, mode='constant') 49 | return {'image': img, 'mask': mask, 'img_id': img_id, 'height':height, 'width':width} 50 | else: 51 | image, img_id, height,width = sample['image'], sample['img_id'], sample['height'],sample['width'] 52 | if isinstance(self.output_size, int): 53 | new_h = new_w = self.output_size 54 | else: 55 | new_h, new_w = self.output_size 56 | 57 | new_h, new_w = int(new_h), int(new_w) 58 | 59 | # resize the image, 60 | # preserve_range means not normalize the image when resize 61 | img = transform.resize(image, (new_h, new_w), preserve_range=True, mode='constant') 62 | return {'image': img, 'height': height,'width':width, 'img_id':img_id} 63 | 64 | class RandomCrop(object): 65 | """Crop randomly the image in a sample. 66 | 67 | Args: 68 | output_size (tuple or int): Desired output size. If int, square crop 69 | is made. 70 | """ 71 | 72 | def __init__(self, output_size): 73 | assert isinstance(output_size, (int, tuple)) 74 | if isinstance(output_size, int): 75 | self.output_size = (output_size, output_size) 76 | else: 77 | assert len(output_size) == 2 78 | self.output_size = output_size 79 | 80 | def __call__(self, sample): 81 | image, mask, img_id, height, width = sample['image'], sample['mask'], sample['img_id'], sample['height'], \ 82 | sample['width'] 83 | 84 | h, w = image.shape[:2] 85 | new_h, new_w = self.output_size 86 | 87 | if h - new_h > 0 and w - new_w > 0: 88 | top = np.random.randint(0, h - new_h) 89 | left = np.random.randint(0, w - new_w) 90 | else: 91 | top = 0 92 | left = 0 93 | 94 | image = image[top: top + new_h, 95 | left: left + new_w] 96 | 97 | mask = mask[top: top + new_h, 98 | left: left + new_w] 99 | 100 | return {'image': image, 'mask': mask, 'img_id':img_id, 'height':height, 'width':width} 101 | 102 | 103 | class ToTensor(object): 104 | """Convert ndarrays in sample to Tensors.""" 105 | def __init__(self, train=True): 106 | self.train = train 107 | 108 | def __call__(self, sample): 109 | if self.train: 110 | # if sample.keys 111 | image, mask, img_id, height, width = sample['image'], sample['mask'], sample['img_id'], sample['height'],sample['width'] 112 | 113 | # swap color axis because 114 | # numpy image: H x W x C 115 | # torch image: C X H X W 116 | image = image.transpose((2, 0, 1)) 117 | mask = mask.transpose((2, 0, 1)) 118 | return {'image': torch.from_numpy(image.astype(np.uint8)), 119 | 'mask': torch.from_numpy(mask.astype(np.uint8)), 120 | 'img_id': img_id, 121 | 'height':height, 122 | 'width':width} 123 | else: 124 | image, height, width, img_id = sample['image'], sample['height'],sample['width'], sample['img_id'] 125 | image = image.transpose((2, 0, 1)) 126 | return {'image': torch.from_numpy(image.astype(np.uint8)), 127 | 'height': height, 128 | 'width': 

# Helper function to show a batch
def show_batch(sample_batched):
    """Show images and masks for a batch of samples."""
    # np.bool was removed in newer NumPy releases; the builtin bool behaves the same here
    images_batch = sample_batched['image'].numpy().astype(np.uint8)
    masks_batch = sample_batched['mask'].numpy().astype(bool)
    batch_size = len(images_batch)
    for i in range(batch_size):
        plt.figure()
        plt.subplot(1, 2, 1)
        plt.tight_layout()
        plt.imshow(images_batch[i].transpose((1, 2, 0)))
        plt.subplot(1, 2, 2)
        plt.tight_layout()
        plt.imshow(np.squeeze(masks_batch[i].transpose((1, 2, 0))))


# Load Data Science Bowl 2018 training dataset
class DSB2018Dataset(Dataset):
    def __init__(self, root_dir, img_id, train=True, transform=None):
        """
        Args:
            :param root_dir (string): directory with all the images
            :param img_id (list): list of image ids
            :param train: if True, read the training set, so the output is image, mask and img_id;
                if False, read the testing set, so the output is image and img_id
            :param transform (callable, optional): optional transform to be applied on a sample
        """
        self.root_dir = root_dir
        self.img_id = img_id
        self.train = train
        self.transform = transform
        self.opt = Option()

    def __len__(self):
        return len(self.img_id)

    def __getitem__(self, idx):
        if self.train:
            img_dir = os.path.join(self.root_dir, self.img_id[idx], 'image.png')
            mask_dir = os.path.join(self.root_dir, self.img_id[idx], 'mask.png')
            img = io.imread(img_dir).astype(np.uint8)
            # as_gray replaces the old as_grey keyword in newer scikit-image;
            # the builtin bool replaces the removed np.bool
            mask = io.imread(mask_dir, as_gray=True).astype(bool)
            mask = np.expand_dims(mask, axis=-1)
            sample = {'image': img, 'mask': mask, 'img_id': self.img_id[idx],
                      'height': img.shape[0], 'width': img.shape[1]}
        else:
            img_dir = os.path.join(self.root_dir, self.img_id[idx], 'image.png')
            img = io.imread(img_dir).astype(np.uint8)
            # size = (img.shape[0], img.shape[1])  # (Height, Width)
            sample = {'image': img, 'img_id': self.img_id[idx],
                      'height': img.shape[0], 'width': img.shape[1]}

        if self.transform:
            sample = self.transform(sample)

        return sample
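
# DSB2018Dataset expects one sub-folder per image id under root_dir (the layout
# prepared by utils.py; see the README):
#
#   root_dir/
#       <img_id>/
#           image.png   # the input image
#           mask.png    # the combined binary mask (training data only)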

def get_train_valid_loader(root_dir, batch_size=16, split=True,
                           shuffle=False, num_workers=4, val_ratio=0.1, pin_memory=False):
    """Utility function for loading and returning training and validation DataLoaders
    :param root_dir: the root directory of the data set
    :param batch_size: batch size of the training and validation sets
    :param split: whether to split the data set into a training set and a validation set
    :param shuffle: whether to shuffle the images in the training and validation sets
    :param num_workers: number of workers loading the data; when using CUDA, set to 1
    :param val_ratio: ratio of the validation set size
    :param pin_memory: store batches in pinned (page-locked) CPU memory rather than
        pageable memory; when using CUDA, set to True
    :return:
        if the data set is split, returns:
            - train_loader: DataLoader for training
            - valid_loader: DataLoader for validation
        else returns:
            - dataloader: DataLoader over the whole data set
    """
    img_id = os.listdir(root_dir)
    if split:
        train_id, val_id = train_test_split(img_id, test_size=val_ratio)

        train_transformed_dataset = DSB2018Dataset(root_dir=root_dir,
                                                   img_id=train_id,
                                                   train=True,
                                                   transform=transforms.Compose([
                                                       RandomCrop(256),
                                                       Rescale(256),
                                                       ToTensor()
                                                   ]))
        val_transformed_dataset = DSB2018Dataset(root_dir=root_dir,
                                                 img_id=val_id,
                                                 train=True,
                                                 transform=transforms.Compose([
                                                     # no RandomCrop: the validation set is not augmented
                                                     Rescale(256),
                                                     ToTensor()
                                                 ]))

        train_loader = DataLoader(train_transformed_dataset, batch_size=batch_size,
                                  shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory)
        val_loader = DataLoader(val_transformed_dataset, batch_size=batch_size,
                                shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
        return (train_loader, val_loader)
    else:
        transformed_dataset = DSB2018Dataset(root_dir=root_dir,
                                             img_id=img_id,
                                             train=True,
                                             transform=transforms.Compose([
                                                 RandomCrop(256),
                                                 Rescale(256),
                                                 ToTensor()
                                             ]))
        dataloader = DataLoader(transformed_dataset, batch_size=batch_size,
                                shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory)
        return dataloader


def get_test_loader(root_dir, batch_size=16, shuffle=False, num_workers=4, pin_memory=False):
    """Utility function for loading and returning the test DataLoader
    :param root_dir: the root directory of the data set
    :param batch_size: batch size of the test set
    :param shuffle: whether to shuffle the images in the test set
    :param num_workers: number of workers loading the data; when using CUDA, set to 1
    :param pin_memory: store batches in pinned (page-locked) CPU memory rather than
        pageable memory; when using CUDA, set to True
    :return:
        - testloader: DataLoader over the whole test set
    """
    img_id = os.listdir(root_dir)
    transformed_dataset = DSB2018Dataset(root_dir=root_dir,
                                         img_id=img_id,
                                         train=False,
                                         transform=transforms.Compose([
                                             # RandomCrop(256),
                                             Rescale(256, train=False),
                                             ToTensor(train=False)
                                         ]))
    testloader = DataLoader(transformed_dataset, batch_size=batch_size,
                            shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory)
    return testloader


if __name__ == '__main__':
    opt = Option()
    trainloader, val_loader = get_train_valid_loader(opt.root_dir, batch_size=opt.batch_size,
                                                     split=True, shuffle=opt.shuffle,
                                                     num_workers=opt.num_workers,
                                                     val_ratio=0.1, pin_memory=opt.pin_memory)

    for i_batch, sample_batched in enumerate(val_loader):
        print(i_batch, sample_batched['image'].size(), sample_batched['mask'].size())
        show_batch(sample_batched)
        plt.show()

    # testloader = get_test_loader(opt.test_dir, batch_size=opt.batch_size, shuffle=opt.shuffle,
    #                              num_workers=opt.num_workers, pin_memory=opt.pin_memory)
    #
    # for i_batch, sample_batched in enumerate(testloader):
    #     # print(i_batch, sample_batched['image'].size(), sample_batched['img_size'])
    #     plt.imshow(np.squeeze(sample_batched['image'][0].cpu().numpy().transpose((1, 2, 0))))
    #     plt.show()
--------------------------------------------------------------------------------
/images/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limingwu8/UNet-pytorch/02f48e3a84b1857f4ad9331466cdcb30c3308b6d/images/loss.png
--------------------------------------------------------------------------------
/images/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limingwu8/UNet-pytorch/02f48e3a84b1857f4ad9331466cdcb30c3308b6d/images/model.png
--------------------------------------------------------------------------------
/images/poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limingwu8/UNet-pytorch/02f48e3a84b1857f4ad9331466cdcb30c3308b6d/images/poster.png
--------------------------------------------------------------------------------
/images/prediction_results01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/limingwu8/UNet-pytorch/02f48e3a84b1857f4ad9331466cdcb30c3308b6d/images/prediction_results01.png
--------------------------------------------------------------------------------
/logs.txt:
--------------------------------------------------------------------------------
epoch: 1, loss: 0.33103285268658683
epoch: 2, loss: 0.18614472023078374
epoch: 3, loss: 0.1604068408764544
epoch: 4, loss: 0.12726177967020444
epoch: 5, loss: 0.12823953293263912
epoch: 6, loss: 0.10918968135402315
epoch: 7, loss: 0.11229322602351506
epoch: 8, loss: 0.11859029726613135
epoch: 9, loss: 0.09356918194819064
epoch: 10, loss: 0.10451770937513738
epoch: 11, loss: 0.09062918071590718
epoch: 12, loss: 0.0898677586976971
epoch: 13, loss: 0.08214519678481988
epoch: 14, loss: 0.0860387058485122
epoch: 15, loss: 0.08358663905944143
epoch: 16, loss: 0.08870539530402138
epoch: 17, loss: 0.0871602004127843
epoch: 18, loss: 0.07960619058992181
epoch: 19, loss: 0.07381952181458473
epoch: 20, loss: 0.07774970245858033
epoch: 21, loss: 0.07553141395605746
epoch: 22, loss: 0.07304480698491846
epoch: 23, loss: 0.07274764075520493
epoch: 24, loss: 0.07348002129722209
epoch: 25, loss: 0.07414966546708629
epoch: 26, loss: 0.07163994298094795
epoch: 27, loss: 0.07016681560448237
epoch: 28, loss: 0.06719717764783473
epoch: 29, loss: 0.07535205754850592
epoch: 30, loss: 0.07445528917014599
epoch: 31, loss: 0.07579240877003897
epoch: 32, loss: 0.07733777050106298
epoch: 33, loss: 0.07181127520189398
epoch: 34, loss: 0.06862004509284383
epoch: 35, loss: 0.06763726978429727
epoch: 36, loss: 0.06839267226556937
epoch: 37, loss: 0.06558844056867418
epoch: 38, loss: 0.06736090876871631
epoch: 39, loss: 0.06605444334092594
epoch: 40, loss: 0.066352663384307
epoch: 41, loss: 0.0662486424580926
epoch: 42, loss: 0.0673335004775297
epoch: 43, loss: 0.06662975256109521
epoch: 44, loss: 0.06480027211918718
epoch: 45, loss: 0.0620374060457661
epoch: 46, loss: 0.062038280114176725
epoch: 47, loss: 0.061097236084086556
epoch: 48, loss: 0.06293723182309241
epoch: 49, loss: 0.06146870314010552
epoch: 50, loss: 0.06152116507291794
epoch: 51, loss: 0.05911821703470889
epoch: 52, loss: 0.06013673606018225
epoch: 53, loss: 0.061912127992226965
epoch: 54, loss: 0.06320016946466196
epoch: 55, loss: 0.058656717162756694
epoch: 56, loss: 0.062449097367269654
epoch: 57, loss: 0.06086103405271258
epoch: 58, loss: 0.05815938332428535
epoch: 59, loss: 0.06166924669274262
epoch: 60, loss: 0.060260106765088584
epoch: 61, loss: 0.06033626859564157
epoch: 62, loss: 0.062080360549901216
epoch: 63, loss: 0.06002412195361796
epoch: 64, loss: 0.06197611632801238
epoch: 65, loss: 0.06483455250660579
epoch: 66, loss: 0.06206372016597362
epoch: 67, loss: 0.061084968969225883
epoch: 68, loss: 0.058868170848914554
epoch: 69, loss: 0.05765812289679334
epoch: 70, loss: 0.05779130721376056
epoch: 71, loss: 0.0607657397964171
epoch: 72, loss: 0.05897572991393861
epoch: 73, loss: 0.058636391358006565
epoch: 74, loss: 0.05857380098175435
epoch: 75, loss: 0.05805366761272862
epoch: 76, loss: 0.05712666690704368
epoch: 77, loss: 0.06052211672067642
epoch: 78, loss: 0.061209887709646
epoch: 79, loss: 0.06136812180990264
epoch: 80, loss: 0.06006380117365292
epoch: 81, loss: 0.0716079513409308
epoch: 82, loss: 0.06876544725327265
epoch: 83, loss: 0.06232867389917374
epoch: 84, loss: 0.061268894134887626
epoch: 85, loss: 0.06329329498112202
epoch: 86, loss: 0.06317874922284059
epoch: 87, loss: 0.06386063292267777
epoch: 88, loss: 0.06303954709853445
epoch: 89, loss: 0.06308409469645648
epoch: 90, loss: 0.057948807830966655
epoch: 91, loss: 0.05706781840750149
epoch: 92, loss: 0.056237807940869106
epoch: 93, loss: 0.05743206571787596
epoch: 94, loss: 0.058953093187440006
epoch: 95, loss: 0.05742329926717849
epoch: 96, loss: 0.05677045908357416
epoch: 97, loss: 0.05815201030955428
epoch: 98, loss: 0.057192111742638406
epoch: 99, loss: 0.055758211273877395
epoch: 100, loss: 0.05554588218884809
epoch: 101, loss: 0.05502665424275966
epoch: 102, loss: 0.054793749802878926
epoch: 103, loss: 0.05362787000125363
epoch: 104, loss: 0.055377585103824026
epoch: 105, loss: 0.054806696117988656
epoch: 106, loss: 0.0540475934921276
epoch: 107, loss: 0.05453269254593622
epoch: 108, loss: 0.0535040781168001
epoch: 109, loss: 0.05285536564354386
epoch: 110, loss: 0.05456300266087055
epoch: 111, loss: 0.053496070649652255
epoch: 112, loss: 0.05356353587870087
epoch: 113, loss: 0.0557218591372172
epoch: 114, loss: 0.053410950471602735
epoch: 115, loss: 0.05335276307804244
epoch: 116, loss: 0.05353523112301316
epoch: 117, loss: 0.05419842362226475
epoch: 118, loss: 0.055108983690539994
epoch: 119, loss: 0.05164731480181217
epoch: 120, loss: 0.052362652584200815
epoch: 121, loss: 0.05382057186216116
epoch: 122, loss: 0.05473461751604364
epoch: 123, loss: 0.0519766104629352
epoch: 124, loss: 0.05300545382002989
epoch: 125, loss: 0.053571448306597415
epoch: 126, loss: 0.060119494751450564
epoch: 127, loss: 0.06127685734203884
epoch: 128, loss: 0.05620227088885648
epoch: 129, loss: 0.05441017679515339
epoch: 130, loss: 0.05430482412200598
epoch: 131, loss: 0.05452641272651298
epoch: 132, loss: 0.053121781952324365
epoch: 133, loss: 0.055828814749561605
epoch: 134, loss: 0.05450011851886908
epoch: 135, loss: 0.05456519614727724
epoch: 136, loss: 0.053525030258156005
epoch: 137, loss: 0.05227891675063542
epoch: 138, loss: 0.05290167503768489
epoch: 139, loss: 0.051492033792393546
epoch: 140, loss: 0.05196222250482866
epoch: 141, loss: 0.05272984469220752
epoch: 142, loss: 0.05157475195647705
epoch: 143, loss: 0.05043308587656135
epoch: 144, loss: 0.04996954920213847
epoch: 145, loss: 0.05204996369069531
epoch: 146, loss: 0.05367879001867203
epoch: 147, loss: 0.06446008135875066
epoch: 148, loss: 0.06818324699997902
epoch: 149, loss: 0.06283547622816903
---------------------------------------
epoch: 1, train loss: 0.015868565237541894
epoch: 1, validation loss: 0.028216259248221097
epoch: 2, train loss: 0.010139408841061948
epoch: 2, validation loss: 0.015205004346706776
epoch: 3, train loss: 0.008225415847194729
epoch: 3, validation loss: 0.008067851537100316
epoch: 4, train loss: 0.008024911444084364
epoch: 4, validation loss: 0.009249072849355132
epoch: 5, train loss: 0.007744730840966872
epoch: 5, validation loss: 0.010535278037857653
epoch: 6, train loss: 0.007240125894002851
epoch: 6, validation loss: 0.009135314567666347
epoch: 7, train loss: 0.008586826053128313
epoch: 7, validation loss: 0.01121617491614957
epoch: 8, train loss: 0.006951868175768338
epoch: 8, validation loss: 0.006695498775684616
epoch: 9, train loss: 0.005808197406690512
epoch: 9, validation loss: 0.0062668067997763205
epoch: 10, train loss: 0.006170906163566741
epoch: 10, validation loss: 0.0070387828063411305
epoch: 11, train loss: 0.005305794069225317
epoch: 11, validation loss: 0.005099189308646504
epoch: 12, train loss: 0.005134738308329685
epoch: 12, validation loss: 0.004918459839933547
epoch: 13, train loss: 0.005301122956756336
epoch: 13, validation loss: 0.004773415804255265
epoch: 14, train loss: 0.0050435521845299605
epoch: 14, validation loss: 0.006896265895163044
epoch: 15, train loss: 0.004782320372451399
epoch: 15, validation loss: 0.004974040781977165
epoch: 16, train loss: 0.004584505116168539
epoch: 16, validation loss: 0.008983465100006876
epoch: 17, train loss: 0.005600128125146056
epoch: 17, validation loss: 0.005967952984433071
epoch: 18, train loss: 0.005072739599257164
epoch: 18, validation loss: 0.013876384065469502
epoch: 19, train loss: 0.004596618907665139
epoch: 19, validation loss: 0.004733207692727323
epoch: 20, train loss: 0.0044614941400675035
epoch: 20, validation loss: 0.005459309588024272
epoch: 21, train loss: 0.004298653972534398
epoch: 21, validation loss: 0.004352191185106093
epoch: 22, train loss: 0.004271777918188529
epoch: 22, validation loss: 0.004307542020952326
epoch: 23, train loss: 0.004504601203930714
epoch: 23, validation loss: 0.00431268560278475
epoch: 24, train loss: 0.004703591267267863
epoch: 24, validation loss: 0.006945784745821313
epoch: 25, train loss: 0.005131937983518414
epoch: 25, validation loss: 0.012569260525565045
epoch: 26, train loss: 0.00486935664023926
epoch: 26, validation loss: 0.004356067256992729
epoch: 27, train loss: 0.004479155123579759
epoch: 27, validation loss: 0.004629992405771221
epoch: 28, train loss: 0.0043486067731779805
epoch: 28, validation loss: 0.0044375104616530495
epoch: 29, train loss: 0.004439866105418893
epoch: 29, validation loss: 0.004441493668069887
epoch: 30, train loss: 0.004233223835553103
epoch: 30, validation loss: 0.004327587678311872
epoch: 31, train loss: 0.0042646818517847835
epoch: 31, validation loss: 0.004751737839596386
epoch: 32, train loss: 0.004192098141161364
epoch: 32, validation loss: 0.004364491160804557
epoch: 33, train loss: 0.00420690279694932
epoch: 33, validation loss: 0.005337441806522372
epoch: 34, train loss: 0.0041793710729771385
epoch: 34, validation loss: 0.003991966724247482
epoch: 35, train loss: 0.003941906106412707
epoch: 35, validation loss: 0.003873995973251352
epoch: 36, train loss: 0.004056455576746025
epoch: 36, validation loss: 0.004460055788802863
epoch: 37, train loss: 0.0040607803175303075
epoch: 37, validation loss: 0.005428318175213847
epoch: 38, train loss: 0.004016927528391233
epoch: 38, validation loss: 0.004102888237530517
epoch: 39, train loss: 0.004023069679539398
epoch: 39, validation loss: 0.004080502637607937
epoch: 40, train loss: 0.004056794899415416
epoch: 40, validation loss: 0.004175145824626706
epoch: 41, train loss: 0.003957550842963641
epoch: 41, validation loss: 0.004130467902220303
epoch: 42, train loss: 0.003940383043405824
epoch: 42, validation loss: 0.0038115905681207997
epoch: 43, train loss: 0.0038555121731945925
epoch: 43, validation loss: 0.0037562646169468734
epoch: 44, train loss: 0.003825380663027613
epoch: 44, validation loss: 0.0036940236589801845
epoch: 45, train loss: 0.003706869766560953
epoch: 45, validation loss: 0.0037209379423415877
epoch: 46, train loss: 0.0036835056508753823
epoch: 46, validation loss: 0.003647721519270544
epoch: 47, train loss: 0.003696605139975722
epoch: 47, validation loss: 0.003573315517720496
epoch: 48, train loss: 0.0037001378519173284
epoch: 48, validation loss: 0.00399266651045426
epoch: 49, train loss: 0.003766165409935252
epoch: 49, validation loss: 0.004451259087221342
epoch: 50, train loss: 0.003606962040066719
epoch: 50, validation loss: 0.0041758878894805115
epoch: 51, train loss: 0.003727747198459916
epoch: 51, validation loss: 0.003893771618110426
epoch: 52, train loss: 0.003698615156497133
epoch: 52, validation loss: 0.003999539164464865
epoch: 53, train loss: 0.0036339263553741953
epoch: 53, validation loss: 0.0035988785555064184
epoch: 54, train loss: 0.003595501768277652
epoch: 54, validation loss: 0.0036877754760618825
epoch: 55, train loss: 0.003567304548685428
epoch: 55, validation loss: 0.0034530372644587734
epoch: 56, train loss: 0.003752325347950605
epoch: 56, validation loss: 0.0037945273021856942
epoch: 57, train loss: 0.0037075082955273426
epoch: 57, validation loss: 0.003880313213043545
epoch: 58, train loss: 0.0036325691233869413
epoch: 58, validation loss: 0.0035438017663275622
epoch: 59, train loss: 0.003799538867884805
epoch: 59, validation loss: 0.00369284876543491
epoch: 60, train loss: 0.0037042484935342177
epoch: 60, validation loss: 0.0038397148078551536
epoch: 61, train loss: 0.0036610558437520195
epoch: 61, validation loss: 0.003653554001801445
epoch: 62, train loss: 0.0036263314583902533
epoch: 62, validation loss: 0.0040423735191622385
epoch: 63, train loss: 0.00473934785149386
epoch: 63, validation loss: 0.005940296651662681
epoch: 64, train loss: 0.004029589808649487
epoch: 64, validation loss: 0.003897702191673701
epoch: 65, train loss: 0.003787595290647415
epoch: 65, validation loss: 0.00408084918536357
epoch: 66, train loss: 0.0037882576695030205
epoch: 66, validation loss: 0.0038307640380280134
epoch: 67, train loss: 0.0036764386021972298
epoch: 67, validation loss: 0.003715176475087604
epoch: 68, train loss: 0.003585206675242824
epoch: 68, validation loss: 0.003774527046415541
epoch: 69, train loss: 0.0036761433713965947
epoch: 69, validation loss: 0.003582546833414541
epoch: 70, train loss: 0.0035914383878063406
epoch: 70, validation loss: 0.003740703260794801
epoch: 71, train loss: 0.0035232889773339575
epoch: 71, validation loss: 0.0038703204801377174
epoch: 72, train loss: 0.0035068604378508494
epoch: 72, validation loss: 0.0038598365803707894
epoch: 73, train loss: 0.003428281925828698
epoch: 73, validation loss: 0.003419407657278118
epoch: 74, train loss: 0.0034155379332120148
epoch: 74, validation loss: 0.0034452481735899277
epoch: 75, train loss: 0.0034640076574846287
epoch: 75, validation loss: 0.0034463037651767383
epoch: 76, train loss: 0.0034322108251736137
epoch: 76, validation loss: 0.00355796576794206
epoch: 77, train loss: 0.0033291662497998867
epoch: 77, validation loss: 0.003442790056540203
epoch: 78, train loss: 0.003361780208784154
epoch: 78, validation loss: 0.0033975621842922857
epoch: 79, train loss: 0.0033002618521105987
epoch: 79, validation loss: 0.0033833555766006012
epoch: 80, train loss: 0.0036026869945314593
epoch: 80, validation loss: 0.0037130868775927606
epoch: 81, train loss: 0.0037469105427439136
epoch: 81, validation loss: 0.004036704254387623
epoch: 82, train loss: 0.003578352366588009
epoch: 82, validation loss: 0.003541765231052244
epoch: 83, train loss: 0.0033991020821516785
epoch: 83, validation loss: 0.003471256989003414
epoch: 84, train loss: 0.0033880100488217907
epoch: 84, validation loss: 0.0035799554183115413
epoch: 85, train loss: 0.0034691731060915326
epoch: 85, validation loss: 0.00355363676129882
epoch: 86, train loss: 0.003491300341995992
epoch: 86, validation loss: 0.0034048402091954675
epoch: 87, train loss: 0.003446070755373186
epoch: 87, validation loss: 0.0034081912313774845
epoch: 88, train loss: 0.0033001575625160244
epoch: 88, validation loss: 0.0035287415729233283
epoch: 89, train loss: 0.003388737561567308
epoch: 89, validation loss: 0.0033009225320361343
epoch: 90, train loss: 0.003287048367891541
epoch: 90, validation loss: 0.003331221701889291
epoch: 91, train loss: 0.0033512271069620378
epoch: 91, validation loss: 0.0033537619816723153
epoch: 92, train loss: 0.003350328573737769
epoch: 92, validation loss: 0.0033697894624562606
epoch: 93, train loss: 0.0032765013513280385
epoch: 93, validation loss: 0.003465234186148169
epoch: 94, train loss: 0.003347681381216097
epoch: 94, validation loss: 0.0033113304000489947
epoch: 95, train loss: 0.003459907314472926
epoch: 95, validation loss: 0.003519280700887218
epoch: 96, train loss: 0.0034590805347879134
epoch: 96, validation loss: 0.0038187463349271967
epoch: 97, train loss: 0.004289287103843531
epoch: 97, validation loss: 0.004910788064819466
epoch: 98, train loss: 0.004398663208555819
epoch: 98, validation loss: 0.004304993514644961
epoch: 99, train loss: 0.003898589932058581
epoch: 99, validation loss: 0.0037296141681584156
epoch: 100, train loss: 0.00383790224716438
epoch: 100, validation loss: 0.0038432876886815375
epoch: 101, train loss: 0.0037744711361714263
epoch: 101, validation loss: 0.0036842571895810503
epoch: 102, train loss: 0.003669828571885775
epoch: 102, validation loss: 0.003570333805868084
epoch: 103, train loss: 0.0035060204899133142
epoch: 103, validation loss: 0.003324545580109158
epoch: 104, train loss: 0.0033778063304240433
epoch: 104, validation loss: 0.0033709382851473727
epoch: 105, train loss: 0.0033193003713689237
epoch: 105, validation loss: 0.003302225685534786
epoch: 106, train loss: 0.0033279100361945815
epoch: 106, validation loss: 0.0038131236980605875
epoch: 107, train loss: 0.0033634215852810967
epoch: 107, validation loss: 0.003232005470823095
epoch: 108, train loss: 0.003266342744157089
epoch: 108, validation loss: 0.0033438821658368528
epoch: 109, train loss: 0.0031916999302890963
epoch: 109, validation loss: 0.003120053717078854
epoch: 110, train loss: 0.0031573023634714078
epoch: 110, validation loss: 0.0030617243540820792
epoch: 111, train loss: 0.0031441925294314254
epoch: 111, validation loss: 0.003155539244734628
epoch: 112, train loss: 0.003145395124803728
epoch: 112, validation loss: 0.0032297498229921953
epoch: 113, train loss: 0.0032332420262630108
epoch: 113, validation loss: 0.0032695726851375145
epoch: 114, train loss: 0.00315697087441412
epoch: 114, validation loss: 0.0031666297057819605
epoch: 115, train loss: 0.0030964426109397394
epoch: 115, validation loss: 0.003181477977738254
epoch: 116, train loss: 0.003152570552864478
epoch: 116, validation loss: 0.0030843412401664314
epoch: 117, train loss: 0.003168282943899754
epoch: 117, validation loss: 0.003308484197947912
epoch: 118, train loss: 0.003125005004739089
epoch: 118, validation loss: 0.003085747018308189
epoch: 119, train loss: 0.003085770997933883
epoch: 119, validation loss: 0.0030297806106544846
epoch: 120, train loss: 0.002978192308723037
epoch: 120, validation loss: 0.0029732626592539634
epoch: 121, train loss: 0.0030812027973816367
epoch: 121, validation loss: 0.002973541259073697
epoch: 122, train loss: 0.003107841260990693
epoch: 122, validation loss: 0.002988897196046551
epoch: 123, train loss: 0.0031088023673539137
epoch: 123, validation loss: 0.00300009371357573
epoch: 124, train loss: 0.0029326244280557727
epoch: 124, validation loss: 0.0030868856597747376
epoch: 125, train loss: 0.0031077278033516697
epoch: 125, validation loss: 0.0034489586183285435
epoch: 126, train loss: 0.0030652917952432757
epoch: 126, validation loss: 0.003061610680215592
epoch: 127, train loss: 0.0030067113921614626
epoch: 127, validation loss: 0.003241096510494724
epoch: 128, train loss: 0.0030760700639899494
epoch: 128, validation loss: 0.0030201042617731427
epoch: 129, train loss: 0.002984120335992098
epoch: 129, validation loss: 0.0028941182181807497
epoch: 130, train loss: 0.003071463350237503
epoch: 130, validation loss: 0.0035841906409259656
epoch: 131, train loss: 0.0029401432284174075
epoch: 131, validation loss: 0.0028866378877738224
epoch: 132, train loss: 0.0029267293692969566
epoch: 132, validation loss: 0.0030343524808956815
epoch: 133, train loss: 0.002814759279463817
epoch: 133, validation loss: 0.0028093447028404443
epoch: 134, train loss: 0.002831698779887821
epoch: 134, validation loss: 0.0030316414981833343
epoch: 135, train loss: 0.0027628995688202766
epoch: 135, validation loss: 0.0028037832058336013
epoch: 136, train loss: 0.00296821598439272
epoch: 136, validation loss: 0.003517935222407083
epoch: 137, train loss: 0.002878547125614895
epoch: 137, validation loss: 0.0028388469540212877
epoch: 138, train loss: 0.002871210221973422
epoch: 138, validation loss: 0.002750325793551766
epoch: 139, train loss: 0.0027445503209360795
epoch: 139, validation loss: 0.002764488188905107
epoch: 140, train loss: 0.002761790087171653
epoch: 140, validation loss: 0.0032241577792226973
epoch: 141, train loss: 0.0028173711157062557
epoch: 141, validation loss: 0.0029042889402675787
epoch: 142, train loss: 0.0026696408553601894
epoch: 142, validation loss: 0.002795181664982641
epoch: 143, train loss: 0.002727247561462483
epoch: 143, validation loss: 0.002797099438868154
epoch: 144, train loss: 0.002723320583453028
epoch: 144, validation loss: 0.002835636204772723
epoch: 145, train loss: 0.003106545209019734
epoch: 145, validation loss: 0.004898942864900007
epoch: 146, train loss: 0.003360044754806838
epoch: 146, validation loss: 0.00351576272948069
epoch: 147, train loss: 0.0032984263991093755
epoch: 147, validation loss: 0.00306284255896437
epoch: 148, train loss: 0.00298782339810732
epoch: 148, validation loss: 0.0030739469141707098
epoch: 149, train loss: 0.002800058246054262
epoch: 149, validation loss: 0.0028533103410926823
epoch: 1, train loss: 0.014265558884709234
epoch: 1, train loss: 0.014552166781219875
epoch: 1, train loss: 0.02002907928641558
epoch: 1, validation loss: 0.04845063721955712
epoch: 2, train loss: 0.00943551088990659
epoch: 2, validation loss: 0.012177029207571228
epoch: 3, train loss: 0.007378399878690887
epoch: 3, validation loss: 0.006984331825775887
epoch: 4, train loss: 0.006415546236939691
epoch: 4, validation loss: 0.005861550347128911
epoch: 5, train loss: 0.006704387411005659
epoch: 5, validation loss: 0.006022822612257146
epoch: 6, train loss: 0.005865540843500229
epoch: 6, validation loss: 0.009845603090613636
epoch: 7, train loss: 0.0048695479521209725
epoch: 7, validation loss: 0.005387518944135352
epoch: 8, train loss: 0.0045631017356765015
epoch: 8, validation loss: 0.005340935022973303
epoch: 9, train loss: 0.004519248921825716
epoch: 9, validation loss: 0.006771967927021767
epoch: 10, train loss: 0.00444973809382018
epoch: 10, validation loss: 0.0045396568837450515
epoch: 11, train loss: 0.004459794678695957
epoch: 11, validation loss: 0.004452567047147609
epoch: 12, train loss: 0.004229764421632634
epoch: 12, validation loss: 0.004204860262906373
epoch: 13, train loss: 0.0042405917771024685
epoch: 13, validation loss: 0.00587329007129171
epoch: 14, train loss: 0.004518817084691615
epoch: 14, validation loss: 0.004882918809776876
epoch: 15, train loss: 0.004403421435980852
epoch: 15, validation loss: 0.006382592689635149
epoch: 16, train loss: 0.004107157994759815
epoch: 16, validation loss: 0.00430423938738766
epoch: 17, train loss: 0.004177126959336931
epoch: 17, validation loss: 0.003959751952050337
epoch: 18, train loss: 0.004054951257571257
epoch: 18, validation loss: 0.004001803640554201
epoch: 19, train loss: 0.004045970447571519
epoch: 19, validation loss: 0.005039996175623652
epoch: 20, train loss: 0.0038808539361502993
epoch: 20, validation loss: 0.003985968988333175
epoch: 21, train loss: 0.0038767083923318493
epoch: 21, validation loss: 0.00436231104740456
epoch: 22, train loss: 0.0038021871971451425
epoch: 22, validation loss: 0.003995066202843367
epoch: 23, train loss: 0.003771199806165537
epoch: 23, validation loss: 0.005740766538612878
epoch: 24, train loss: 0.003725990300896156
epoch: 24, validation loss: 0.003951248170724555
epoch: 25, train loss: 0.0037854561554407007
epoch: 25, validation loss: 0.004472242251260957
epoch: 26, train loss: 0.0040117106392707795
epoch: 26, validation loss: 0.003941103093214889
epoch: 27, train loss: 0.0038860166386684176
epoch: 27, validation loss: 0.0036692149492342082
epoch: 28, train loss: 0.0038239430716480584
epoch: 28, validation loss: 0.0037854675815176607
epoch: 29, train loss: 0.003757181341078744
epoch: 29, validation loss: 0.0038420916048448476
epoch: 30, train loss: 0.003854892406594101
epoch: 30, validation loss: 0.0036999444240954386
epoch: 31, train loss: 0.00357691342533129
epoch: 31, validation loss: 0.003567967341462178
epoch: 32, train loss: 0.0037864248884554525
epoch: 32, validation loss: 0.003786036661311762
epoch: 33, train loss: 0.003692831053316692
epoch: 33, validation loss: 0.003941342457016902
epoch: 34, train loss: 0.0037123523837891384
epoch: 34, validation loss: 0.003680602550061781
epoch: 35, train loss: 0.003727554607747206
epoch: 35, validation loss: 0.003720442861763399
epoch: 36, train loss: 0.003599712652946586
epoch: 36, validation loss: 0.0036742352505228413
epoch: 37, train loss: 0.0036109596317878013
epoch: 37, validation loss: 0.003674289628640929
epoch: 38, train loss: 0.003980133867827221
epoch: 38, validation loss: 0.0037984492173835415
epoch: 39, train loss: 0.004768742097056722
epoch: 39, validation loss: 0.018835161159287638
epoch: 40, train loss: 0.004142572865458468
epoch: 40, validation loss: 0.004096289512826435
epoch: 41, train loss: 0.003960762426232421
epoch: 41, validation loss: 0.004200280602298566
epoch: 42, train loss: 0.003839832800092982
epoch: 42, validation loss: 0.003892338487194545
epoch: 43, train loss: 0.003841231124565178
epoch: 43, validation loss: 0.004115321798555886
epoch: 44, train loss: 0.003568481690996322
epoch: 44, validation loss: 0.004231223808740502
epoch: 45, train loss: 0.003553380920964094
epoch: 45, validation loss: 0.003594165584489481
epoch: 46, train loss: 0.003625074052444936
epoch: 46, validation loss: 0.003647049249552969
epoch: 47, train loss: 0.0034793101399395597
epoch: 47, validation loss: 0.003648280263391893
epoch: 48, train loss: 0.0035523947383930433
epoch: 48, validation loss: 0.003983998754576071
epoch: 49, train loss: 0.0035050253308135677
epoch: 49, validation loss: 0.0035776277420236104
epoch: 50, train loss: 0.003521569814551529
epoch: 50, validation loss: 0.0036355469431450117
epoch: 51, train loss: 0.003532189431030359
epoch: 51, validation loss: 0.003882765547553105
epoch: 52, train loss: 0.003476262914254693
epoch: 52, validation loss: 0.003612111920296256
epoch: 53, train loss: 0.00391253787610266
epoch: 53, validation loss: 0.007566547327077211
epoch: 54, train loss: 0.003880762984701255
epoch: 54, validation loss: 0.00410717286503137
epoch: 55, train loss: 0.003527669455380384
epoch: 55, validation loss: 0.0046990529369952076
epoch: 56, train loss: 0.003466639210582768
epoch: 56, validation loss: 0.003687891664344873
epoch: 57, train loss: 0.003428840215872374
epoch: 57, validation loss: 0.0035701193360250386
epoch: 58, train loss: 0.0034926861561649473
epoch: 58, validation loss: 0.003487263097247081
epoch: 59, train loss: 0.003360906779568389
epoch: 59, validation loss: 0.003969762696703868
epoch: 60, train loss: 0.003359914196664421
epoch: 60, validation loss: 0.003523121351626382
epoch: 61, train loss: 0.0034660407578322424
epoch: 61, validation loss: 0.0035355872849919903
epoch: 62, train loss: 0.003706625339947332
epoch: 62, validation loss: 0.003769942851208929
epoch: 63, train loss: 0.0034576555896260056
epoch: 63, validation loss: 0.0034354441979927803
epoch: 64, train loss: 0.0033033622790183594
epoch: 64, validation loss: 0.003423591992303507
epoch: 65, train loss: 0.003287611628077912
epoch: 65, validation loss: 0.003642299949233212
epoch: 66, train loss: 0.003328352190467651
epoch: 66, validation loss: 0.004014632205909758
epoch: 67, train loss: 0.0037257384144350468
epoch: 67, validation loss: 0.003817385535186796
epoch: 68, train loss: 0.003524832004313643
epoch: 68, validation loss: 0.0036396960825172825
epoch: 69, train loss: 0.003418319143515519
epoch: 69, validation loss: 0.003650760361507757
epoch: 70, train loss: 0.003433951480965907
epoch: 70, validation loss: 0.0035183449178489284
epoch: 71, train loss: 0.003398108084502307
epoch: 71, validation loss: 0.003438008309745077
epoch: 72, train loss: 0.0033286804639136615
epoch: 72, validation loss: 0.003760237524758524
epoch: 73, train loss: 0.0035008724568841073
epoch: 73, validation loss: 0.003510281499197234
epoch: 74, train loss: 0.003255145089889245
epoch: 74, validation loss: 0.0037097219766965553
epoch: 75, train loss: 0.0033095482038720133
epoch: 75, validation loss: 0.003884773312219933
epoch: 76, train loss: 0.003136999806262565
epoch: 76, validation loss: 0.0033873589292391024
epoch: 77, train loss: 0.0033422356431262804
epoch: 77, validation loss: 0.0036874922242627214
epoch: 78, train loss: 0.00332638796783403
epoch: 78, validation loss: 0.003438512558367715
epoch: 79, train loss: 0.003212812491020753
epoch: 79, validation loss: 0.0034266616871107872
epoch: 80, train loss: 0.0032091316694446265
epoch: 80, validation loss: 0.003443409702671108
epoch: 81, train loss: 0.003121261217725613
epoch: 81, validation loss: 0.003526950838850505
epoch: 82, train loss: 0.003093621451067885
epoch: 82, validation loss: 0.003360453492669917
epoch: 83, train loss: 0.00315153128954012
epoch: 83, validation loss: 0.0034537090183194004
epoch: 84, train loss: 0.0030831101954378695
epoch: 84, validation loss: 0.003467225928359957
epoch: 85, train loss: 0.0033467403469394095
epoch: 85, validation loss: 0.003596666032698617
epoch: 86, train loss: 0.0032430638869603476
epoch: 86, validation loss: 0.003372374588429038
epoch: 87, train loss: 0.0033472229172143574
epoch: 87, validation loss: 0.007170477131409432
epoch: 88, train loss: 0.0035891169450472838
epoch: 88, validation loss: 0.004864927770486518
epoch: 89, train loss: 0.003655438967802829
epoch: 89, validation loss: 0.003942957231357916
epoch: 90, train loss: 0.00345227416376174
epoch: 90, validation loss: 0.004809244466361715
epoch: 91, train loss: 0.003363086834278075
epoch: 91, validation loss: 0.0037518125305424873
epoch: 92, train loss: 0.0031719649794386393
epoch: 92, validation loss: 0.0034810983534179518
epoch: 93, train loss: 0.003235728330107075
epoch: 93, validation loss: 0.0037459481451938403
epoch: 94, train loss: 0.0032925994798022125
epoch: 94, validation loss: 0.0034126608785408647
epoch: 95, train loss: 0.003304689603658458
epoch: 95, validation loss: 0.0035593838834050875
epoch: 96, train loss: 0.003109529073879889
epoch: 96, validation loss: 0.003432340975572814
epoch: 97, train loss: 0.003260116702931042
epoch: 97, validation loss: 0.003545890548335972
epoch: 98, train loss: 0.0030982078826842616
epoch: 98, validation loss: 0.003507545023266949
epoch: 99, train loss: 0.0031256152511532627
epoch: 99, validation loss: 0.00330268758446423
epoch: 100, train loss: 0.003114136346734776
epoch: 100, validation loss: 0.003477171897443373
epoch: 101, train loss: 0.003127097360678573
epoch: 101, validation loss: 0.003222222648449798
epoch: 102, train loss: 0.0030163543462555604
epoch: 102, validation loss: 0.003223189778292357
epoch: 103, train loss: 0.0030038970712553802
epoch: 103, validation loss: 0.003333435796979648
epoch: 104, train loss: 0.0032102638364431275
epoch: 104, validation loss: 0.003202551503234835
epoch: 105, train loss: 0.0030864900530966164
epoch: 105, validation loss: 0.0033982802015631946
epoch: 106, train loss: 0.0031764730872898355
epoch: 106, validation loss: 0.0033982676912599533
epoch: 107, train loss: 0.003037628583350585
epoch: 107, validation loss: 0.003638269963549144
epoch: 108, train loss: 0.0030540223133010453
epoch: 108, validation loss: 0.003394437203211571
epoch: 109, train loss: 0.003064653528838806
epoch: 109, validation loss: 0.0035086989180365606
epoch: 110, train loss: 0.002941680661283718
epoch: 110, validation loss: 0.0031868302221618483
epoch: 111, train loss: 0.0029396996470430793
epoch: 111, validation loss: 0.0032729421978566183
epoch: 112, train loss: 0.002913912563626446
epoch: 112, validation loss: 0.0031965944860408557
epoch: 113, train loss: 0.002841824878804126
epoch: 113, validation loss: 0.0035070890366141476
epoch: 114, train loss: 0.002876807680696397
epoch: 114, validation loss: 0.003264243589408362
epoch: 115, train loss: 0.0028847467716455856
epoch: 115, validation loss: 0.0034375124569259474
epoch: 116, train loss: 0.0028724296027748146
epoch: 116, validation loss: 0.0032824363503883135
epoch: 117, train loss: 0.0028163511336937075
epoch: 117, validation loss: 0.00320189451770996
epoch: 118, train loss: 0.002775617611052385
epoch: 118, validation loss: 0.003290733072295118
epoch: 119, train loss: 0.002858436025469062
epoch: 119, validation loss: 0.003260946095879398
epoch: 120, train loss: 0.002899168525663379
epoch: 120, validation loss: 0.0032038555772446876
epoch: 121, train loss: 0.002904616376604409
epoch: 121, validation loss: 0.003241114373972167
epoch: 122, train loss: 0.0027990634761639495
epoch: 122, validation loss: 0.0032603538303232903
epoch: 123, train loss: 0.0027882734786218674
epoch: 123, validation loss: 0.003174903288261214
epoch: 124, train loss: 0.0028506474844357662
epoch: 124, validation loss: 0.003207363243868102
epoch: 125, train loss: 0.002870516996725677
epoch: 125, validation loss: 0.0032620892595888965
epoch: 126, train loss: 0.0027225399745093846
epoch: 126, validation loss: 0.003242547665514163
epoch: 127, train loss: 0.002790182140064279
epoch: 127, validation loss: 0.003290299492985455
epoch: 128, train loss: 0.002742089355861765
epoch: 128, validation loss: 0.003339979964405743
epoch: 129, train loss: 0.0027358706107681267
epoch: 129, validation loss: 0.0033830830187939888
epoch: 130, train loss: 0.0028481129155329015
epoch: 130, validation loss: 0.0032921605256956013
epoch: 131, train loss: 0.0027999690645814537
epoch: 131, validation loss: 0.003387759314544165
713 | epoch: 132, train loss: 0.0027403715089778995 714 | epoch: 132, validation loss: 0.003385363952882254 715 | epoch: 133, train loss: 0.0026409481103789943 716 | epoch: 133, validation loss: 0.003198508618038092 717 | epoch: 134, train loss: 0.0027662072837303328 718 | epoch: 134, validation loss: 0.0034163475259026484 719 | epoch: 135, train loss: 0.0027240659318753142 720 | epoch: 135, validation loss: 0.0033823887247647812 721 | epoch: 136, train loss: 0.002799947398589618 722 | epoch: 136, validation loss: 0.003250602299152915 723 | epoch: 137, train loss: 0.0027371851903437382 724 | epoch: 137, validation loss: 0.003420100505672284 725 | epoch: 138, train loss: 0.0026950765042805156 726 | epoch: 138, validation loss: 0.003256491649506697 727 | epoch: 139, train loss: 0.002753994342674861 728 | epoch: 139, validation loss: 0.003453834566162593 729 | epoch: 140, train loss: 0.0027248317415935683 730 | epoch: 140, validation loss: 0.003257149079842354 731 | epoch: 141, train loss: 0.002639220490013782 732 | epoch: 141, validation loss: 0.0036224467652057534 733 | epoch: 142, train loss: 0.002718355784665293 734 | epoch: 142, validation loss: 0.003329496385891046 735 | epoch: 143, train loss: 0.002657597025803863 736 | epoch: 143, validation loss: 0.0033558544827930964 737 | epoch: 144, train loss: 0.0029867220373444294 738 | epoch: 144, validation loss: 0.0033276590949563836 739 | epoch: 145, train loss: 0.0029870045354364327 740 | epoch: 145, validation loss: 0.003258352015000671 741 | epoch: 146, train loss: 0.0027821362389260857 742 | epoch: 146, validation loss: 0.0035879471297584364 743 | epoch: 147, train loss: 0.0026906180947673065 744 | epoch: 147, validation loss: 0.0034786394950169237 745 | epoch: 148, train loss: 0.002719710925424079 746 | epoch: 148, validation loss: 0.003529310615649864 747 | epoch: 149, train loss: 0.0027294801378230354 748 | epoch: 149, validation loss: 0.003384117370665963 749 | epoch: 1, train loss: 23.96291160583496 750 | epoch: 1, validation loss: 24.57157982720269 751 | epoch: 2, train loss: 24.00279900902196 752 | epoch: 2, validation loss: 24.57157982720269 753 | epoch: 3, train loss: 23.982500678614567 754 | epoch: 3, validation loss: 24.57157982720269 755 | epoch: 4, train loss: 23.97407072468808 756 | epoch: 4, validation loss: 24.57157982720269 757 | epoch: 1, train loss: 0.3828646013219106 758 | epoch: 1, validation loss: 0.2675665583875444 759 | epoch: 2, train loss: 0.19970255786258923 760 | epoch: 2, validation loss: 0.2204214595258236 761 | epoch: 3, train loss: 0.1795755325767555 762 | epoch: 3, validation loss: 0.19747109711170197 763 | epoch: 4, train loss: 0.15518866288230607 764 | epoch: 4, validation loss: 0.20600702779160607 765 | epoch: 5, train loss: 0.15695916890705885 766 | epoch: 5, validation loss: 0.13578750358687508 767 | epoch: 6, train loss: 0.15654674025350496 768 | epoch: 6, validation loss: 0.612781337565846 769 | epoch: 7, train loss: 0.13923000436472266 770 | epoch: 7, validation loss: 0.17213589284155104 771 | epoch: 8, train loss: 0.12682433534217508 772 | epoch: 8, validation loss: 0.11443191435601976 773 | epoch: 9, train loss: 0.11357639180986505 774 | epoch: 9, validation loss: 0.10207787859770986 775 | epoch: 10, train loss: 0.11205616755116928 776 | epoch: 10, validation loss: 0.14000267245703274 777 | epoch: 11, train loss: 0.10757422780520037 778 | epoch: 11, validation loss: 0.23080816864967346 779 | epoch: 12, train loss: 0.11219737267023638 780 | epoch: 12, validation loss: 0.16613530864318213 781 | 
epoch: 13, train loss: 0.10527363110725817 782 | epoch: 13, validation loss: 0.08699872096379598 783 | epoch: 14, train loss: 0.09305648990955792 784 | epoch: 14, validation loss: 0.09639191544718212 785 | epoch: 15, train loss: 0.08872536687474501 786 | epoch: 15, validation loss: 0.19343752413988113 787 | epoch: 16, train loss: 0.0869163348290481 788 | epoch: 16, validation loss: 0.11840159073472023 789 | epoch: 17, train loss: 0.09242151344293043 790 | epoch: 17, validation loss: 0.24112444950474632 791 | epoch: 18, train loss: 0.09186348833731915 792 | epoch: 18, validation loss: 0.09183072505725755 793 | epoch: 19, train loss: 0.09125518367478722 794 | epoch: 19, validation loss: 0.7025548554956913 795 | epoch: 20, train loss: 0.08620007502797403 796 | epoch: 20, validation loss: 0.0891111869778898 797 | epoch: 21, train loss: 0.08921212389280922 798 | epoch: 21, validation loss: 0.17353802464074558 799 | epoch: 22, train loss: 0.08554875733036744 800 | epoch: 22, validation loss: 0.0915277289847533 801 | epoch: 23, train loss: 0.07958827584393714 802 | epoch: 23, validation loss: 0.08311219182279375 803 | epoch: 24, train loss: 0.07763194849126433 804 | epoch: 24, validation loss: 0.096598572201199 805 | epoch: 25, train loss: 0.08032672891491338 806 | epoch: 25, validation loss: 0.15292669667137992 807 | epoch: 26, train loss: 0.12857376436065687 808 | epoch: 26, validation loss: 0.1319846875137753 809 | epoch: 27, train loss: 0.09679312356992771 810 | epoch: 27, validation loss: 0.18178077911337218 811 | epoch: 28, train loss: 0.08288398904627875 812 | epoch: 28, validation loss: 0.1230815156466431 813 | epoch: 29, train loss: 0.08203412658583961 814 | epoch: 29, validation loss: 0.09432533631722133 815 | epoch: 1, train loss: 0.2740719028209385 816 | epoch: 1, validation loss: 0.21813792652553982 817 | epoch: 2, train loss: 0.18967716160573458 818 | epoch: 2, validation loss: 0.1845573435227076 819 | epoch: 3, train loss: 0.16510677818012864 820 | epoch: 3, validation loss: 0.18309610999292797 821 | epoch: 4, train loss: 0.16156289226522572 822 | epoch: 4, validation loss: 0.39732807377974194 823 | epoch: 5, train loss: 0.15167616895939173 824 | epoch: 5, validation loss: 0.1611499802933799 825 | epoch: 1, train loss: 0.277186408372862 826 | epoch: 2, train loss: 0.18837885753739447 827 | epoch: 3, train loss: 0.17427526804662885 828 | epoch: 4, train loss: 0.15558438189327717 829 | epoch: 5, train loss: 0.14636531024284305 830 | epoch: 6, train loss: 0.14941920056229546 831 | epoch: 7, train loss: 0.13422018134345612 832 | epoch: 8, train loss: 0.1252529135949555 833 | epoch: 9, train loss: 0.12221751849920977 834 | epoch: 10, train loss: 0.10997288208454847 835 | epoch: 11, train loss: 0.11389069103946288 836 | epoch: 12, train loss: 0.10401701590134985 837 | epoch: 13, train loss: 0.08325865197305878 838 | epoch: 14, train loss: 0.08751861356376182 839 | epoch: 15, train loss: 0.08291996811472234 840 | epoch: 16, train loss: 0.0792127731256187 841 | epoch: 17, train loss: 0.07943389079134379 842 | epoch: 18, train loss: 0.0788346165347667 843 | epoch: 19, train loss: 0.08283437455871276 844 | epoch: 20, train loss: 0.078623670419412 845 | epoch: 21, train loss: 0.07709836627223662 846 | epoch: 22, train loss: 0.07235375574479501 847 | epoch: 23, train loss: 0.07104153946662943 848 | epoch: 24, train loss: 0.06955032298962276 849 | epoch: 25, train loss: 0.06981704253259868 850 | epoch: 26, train loss: 0.07086721226750385 851 | epoch: 27, train loss: 0.0734735892065579 852 | 
epoch: 28, train loss: 0.0717286936761368 853 | epoch: 29, train loss: 0.06950068850779817 854 | epoch: 30, train loss: 0.06794847816317565 855 | epoch: 31, train loss: 0.0719312961612429 856 | epoch: 32, train loss: 0.0746229820042139 857 | epoch: 33, train loss: 0.071206424562704 858 | epoch: 34, train loss: 0.06894200403864185 859 | epoch: 35, train loss: 0.07165893557525817 860 | epoch: 36, train loss: 0.06511085444972628 861 | epoch: 37, train loss: 0.06631609684388552 862 | epoch: 38, train loss: 0.06559606557268471 863 | epoch: 39, train loss: 0.0643584192065256 864 | epoch: 40, train loss: 0.06785663627531557 865 | epoch: 41, train loss: 0.06456734149140261 866 | epoch: 42, train loss: 0.0665020855321061 867 | epoch: 43, train loss: 0.07127233273127959 868 | epoch: 44, train loss: 0.06657074280970153 869 | epoch: 45, train loss: 0.0650257889979652 870 | epoch: 46, train loss: 0.06274731202228438 871 | epoch: 47, train loss: 0.06283282459757868 872 | epoch: 48, train loss: 0.06490637047127598 873 | epoch: 49, train loss: 0.06395778771755951 874 | epoch: 50, train loss: 0.06219145704415582 875 | epoch: 51, train loss: 0.061704386446979786 876 | epoch: 52, train loss: 0.06212229447971497 877 | epoch: 53, train loss: 0.06263601285449806 878 | epoch: 54, train loss: 0.062418830736229815 879 | epoch: 55, train loss: 0.06028208580045473 880 | epoch: 56, train loss: 0.05999947982352404 881 | epoch: 57, train loss: 0.06267681396344588 882 | epoch: 58, train loss: 0.06254866710376172 883 | epoch: 59, train loss: 0.061833995017444805 884 | epoch: 60, train loss: 0.06288228399075922 885 | epoch: 61, train loss: 0.060699197818480786 886 | epoch: 62, train loss: 0.058909026196315176 887 | epoch: 63, train loss: 0.06073362994495602 888 | epoch: 64, train loss: 0.061914546481732814 889 | epoch: 65, train loss: 0.06269714413654237 890 | epoch: 66, train loss: 0.05916253896430135 891 | epoch: 67, train loss: 0.0608932884365675 892 | epoch: 68, train loss: 0.05897937391308092 893 | epoch: 69, train loss: 0.05898021722567223 894 | epoch: 70, train loss: 0.058693850612533946 895 | epoch: 71, train loss: 0.05917047591702569 896 | epoch: 72, train loss: 0.058193922663728394 897 | epoch: 73, train loss: 0.05803509145265534 898 | epoch: 74, train loss: 0.06624723348899611 899 | epoch: 75, train loss: 0.06810145658840026 900 | epoch: 76, train loss: 0.07130376307205075 901 | epoch: 77, train loss: 0.06355864998130571 902 | epoch: 78, train loss: 0.06018848752691632 903 | epoch: 79, train loss: 0.060795812047131005 904 | epoch: 80, train loss: 0.058888648658813464 905 | epoch: 81, train loss: 0.059340994095518476 906 | epoch: 82, train loss: 0.058033558495697524 907 | epoch: 83, train loss: 0.05628838962210076 908 | epoch: 84, train loss: 0.05560901345285986 909 | epoch: 85, train loss: 0.05522337808672871 910 | epoch: 86, train loss: 0.05710372155798333 911 | epoch: 87, train loss: 0.055660449677989596 912 | epoch: 88, train loss: 0.05463033351337626 913 | epoch: 89, train loss: 0.05442363065889194 914 | epoch: 90, train loss: 0.0577563338114747 915 | epoch: 91, train loss: 0.05497693113007006 916 | epoch: 92, train loss: 0.05444350419566035 917 | epoch: 93, train loss: 0.05522836857874479 918 | epoch: 94, train loss: 0.05465114624461248 919 | epoch: 95, train loss: 0.05237867465863625 920 | epoch: 96, train loss: 0.05413480597503838 921 | epoch: 97, train loss: 0.05568463867530227 922 | epoch: 98, train loss: 0.053137250510709624 923 | epoch: 99, train loss: 0.05317651635656754 924 | epoch: 1, train loss: 
0.5941495483829862 925 | epoch: 2, train loss: 0.3667884277445929 926 | epoch: 3, train loss: 0.29728695643799646 927 | epoch: 4, train loss: 0.22445833683013916 928 | epoch: 5, train loss: 0.18188504005471864 929 | epoch: 6, train loss: 0.16683986641111828 930 | epoch: 7, train loss: 0.13809215490307128 931 | epoch: 8, train loss: 0.11718751348200299 932 | epoch: 9, train loss: 0.11859201976940745 933 | epoch: 10, train loss: 0.10627831518650055 934 | epoch: 11, train loss: 0.10241484251760301 935 | epoch: 12, train loss: 0.09309287049940654 936 | epoch: 13, train loss: 0.09046624317055657 937 | epoch: 14, train loss: 0.08708651417068072 938 | epoch: 15, train loss: 0.0834070240990037 939 | epoch: 16, train loss: 0.08452448878614675 940 | epoch: 17, train loss: 0.08042745398623603 941 | epoch: 18, train loss: 0.07780462476824011 942 | epoch: 19, train loss: 0.08083256946078368 943 | epoch: 20, train loss: 0.08058207925586473 944 | epoch: 21, train loss: 0.07778725134474891 945 | epoch: 22, train loss: 0.07650562668485301 946 | epoch: 23, train loss: 0.073620008038623 947 | epoch: 24, train loss: 0.071189781916993 948 | epoch: 25, train loss: 0.07264291761176926 949 | epoch: 26, train loss: 0.07017774631579717 950 | epoch: 27, train loss: 0.07065729025219168 951 | epoch: 28, train loss: 0.07034068545770078 952 | epoch: 29, train loss: 0.06695743064795222 953 | epoch: 30, train loss: 0.06939043539265792 954 | epoch: 31, train loss: 0.06787205975325335 955 | epoch: 32, train loss: 0.06569787434169225 956 | epoch: 33, train loss: 0.06501557607026327 957 | epoch: 34, train loss: 0.06415718137508347 958 | epoch: 35, train loss: 0.07016832931410699 959 | epoch: 36, train loss: 0.06890045722857827 960 | epoch: 37, train loss: 0.069160179545482 961 | epoch: 38, train loss: 0.06536990554914587 962 | epoch: 39, train loss: 0.06510624750739052 963 | epoch: 40, train loss: 0.0647027785224574 964 | epoch: 41, train loss: 0.06570216543262913 965 | epoch: 42, train loss: 0.06209978425786609 966 | epoch: 43, train loss: 0.06268693267234735 967 | epoch: 44, train loss: 0.06873470518205847 968 | epoch: 45, train loss: 0.07090265973515454 969 | epoch: 46, train loss: 0.06469226211664222 970 | epoch: 47, train loss: 0.06396718988461154 971 | epoch: 48, train loss: 0.06562187012639784 972 | epoch: 49, train loss: 0.061425994017294476 973 | epoch: 50, train loss: 0.06327343732118607 974 | epoch: 51, train loss: 0.06260878264549233 975 | epoch: 52, train loss: 0.06078485177741164 976 | epoch: 53, train loss: 0.06146365896399532 977 | epoch: 54, train loss: 0.06057103483804634 978 | epoch: 55, train loss: 0.06106211484542915 979 | epoch: 56, train loss: 0.060569748193735166 980 | epoch: 57, train loss: 0.07118220751484235 981 | epoch: 58, train loss: 0.07009587951359295 982 | epoch: 59, train loss: 0.06330401343958718 983 | epoch: 60, train loss: 0.06354105725352253 984 | epoch: 61, train loss: 0.0612327755384502 985 | epoch: 62, train loss: 0.060450187928619836 986 | epoch: 63, train loss: 0.05850700315620218 987 | epoch: 64, train loss: 0.058409513550854865 988 | epoch: 65, train loss: 0.060791550958085624 989 | epoch: 66, train loss: 0.05993979823376451 990 | epoch: 67, train loss: 0.05810497053677127 991 | epoch: 68, train loss: 0.059137832994262375 992 | epoch: 69, train loss: 0.05742869623714969 993 | epoch: 70, train loss: 0.05879728584772065 994 | epoch: 71, train loss: 0.05840823768327633 995 | epoch: 72, train loss: 0.05772672247673784 996 | epoch: 73, train loss: 0.056660872928443407 997 | epoch: 74, 
train loss: 0.057248812479277454 998 | epoch: 75, train loss: 0.056507262002144544 999 | epoch: 76, train loss: 0.05839865628097739 1000 | epoch: 77, train loss: 0.06587701840769677 1001 | epoch: 78, train loss: 0.05979168361851147 1002 | epoch: 79, train loss: 0.057634737032155194 1003 | epoch: 80, train loss: 0.059386420994997025 1004 | epoch: 81, train loss: 0.06301286789987769 1005 | epoch: 82, train loss: 0.05989566773530983 1006 | epoch: 83, train loss: 0.05993918427044437 1007 | epoch: 84, train loss: 0.05751245060846919 1008 | epoch: 85, train loss: 0.05544612295038644 1009 | epoch: 86, train loss: 0.05664837249510345 1010 | epoch: 87, train loss: 0.0571145187797291 1011 | epoch: 88, train loss: 0.05492097250230256 1012 | epoch: 89, train loss: 0.0555287760105871 1013 | epoch: 90, train loss: 0.056578707082995346 1014 | epoch: 91, train loss: 0.05510789909887882 1015 | epoch: 92, train loss: 0.05554366253671192 1016 | epoch: 93, train loss: 0.05409663392319566 1017 | epoch: 94, train loss: 0.054384895910819374 1018 | epoch: 95, train loss: 0.05347696589749484 1019 | epoch: 96, train loss: 0.0545265967292445 1020 | epoch: 97, train loss: 0.05482120270885173 1021 | epoch: 98, train loss: 0.05353864175932748 1022 | epoch: 99, train loss: 0.05321143350253502 1023 | epoch: 1, train loss: 0.28403054142282125 1024 | epoch: 2, train loss: 0.13974222115107945 1025 | epoch: 3, train loss: 0.12659883960371926 1026 | epoch: 4, train loss: 0.11807466316081229 1027 | epoch: 1, train loss: 0.6101753684607419 1028 | epoch: 2, train loss: 0.3564839688214389 1029 | epoch: 3, train loss: 0.30562718348069623 1030 | epoch: 4, train loss: 0.2556102912534367 1031 | epoch: 5, train loss: 0.2219574126330289 1032 | epoch: 6, train loss: 0.20165764879096637 1033 | epoch: 7, train loss: 0.18495498191226611 1034 | epoch: 8, train loss: 0.1641310927542773 1035 | epoch: 9, train loss: 0.14764011177149686 1036 | epoch: 10, train loss: 0.16208295253190128 1037 | epoch: 11, train loss: 0.14655376564372669 1038 | epoch: 12, train loss: 0.127106520939957 1039 | epoch: 13, train loss: 0.12054856324737723 1040 | epoch: 14, train loss: 0.11076663840900768 1041 | epoch: 15, train loss: 0.10841174626892264 1042 | epoch: 16, train loss: 0.104097342626615 1043 | epoch: 17, train loss: 0.09969358146190643 1044 | epoch: 18, train loss: 0.09725810384208505 1045 | epoch: 19, train loss: 0.09479079734195363 1046 | epoch: 20, train loss: 0.08939473534172232 1047 | epoch: 21, train loss: 0.08882232484492389 1048 | epoch: 22, train loss: 0.0865879709070379 1049 | epoch: 23, train loss: 0.08486979387023232 1050 | epoch: 24, train loss: 0.08424380828033794 1051 | epoch: 25, train loss: 0.08491134101694281 1052 | epoch: 26, train loss: 0.08198995888233185 1053 | epoch: 27, train loss: 0.07903746312314813 1054 | epoch: 28, train loss: 0.07908666878938675 1055 | epoch: 29, train loss: 0.07739397206089714 1056 | epoch: 30, train loss: 0.07638539848002521 1057 | epoch: 31, train loss: 0.07850584320046684 1058 | epoch: 32, train loss: 0.0773794102397832 1059 | epoch: 33, train loss: 0.07882856306704608 1060 | epoch: 34, train loss: 0.07250491766767068 1061 | epoch: 35, train loss: 0.07068298316814682 1062 | epoch: 36, train loss: 0.0715364661406387 1063 | epoch: 37, train loss: 0.07108580117875879 1064 | epoch: 38, train loss: 0.0704498602585359 1065 | epoch: 39, train loss: 0.06985536827282472 1066 | epoch: 40, train loss: 0.06914959610863165 1067 | epoch: 41, train loss: 0.06779053434729576 1068 | epoch: 42, train loss: 
0.06850200248035518 1069 | epoch: 43, train loss: 0.06884505362673239 1070 | epoch: 44, train loss: 0.06689381430094893 1071 | epoch: 45, train loss: 0.06806272505359216 1072 | epoch: 46, train loss: 0.0658667314458977 1073 | epoch: 47, train loss: 0.06483228978785602 1074 | epoch: 48, train loss: 0.06481768123128197 1075 | epoch: 49, train loss: 0.06471413746476173 1076 | epoch: 50, train loss: 0.06494498151269826 1077 | epoch: 51, train loss: 0.06654572690075095 1078 | epoch: 52, train loss: 0.06524208831516179 1079 | epoch: 53, train loss: 0.06452342258258299 1080 | epoch: 54, train loss: 0.06511116671291264 1081 | epoch: 55, train loss: 0.0627700002356009 1082 | epoch: 56, train loss: 0.06539819220250304 1083 | epoch: 57, train loss: 0.06403743069280278 1084 | epoch: 58, train loss: 0.06238444759087129 1085 | epoch: 59, train loss: 0.06295610456304117 1086 | epoch: 60, train loss: 0.062195501205596054 1087 | epoch: 61, train loss: 0.06339297138831833 1088 | epoch: 62, train loss: 0.06209187209606171 1089 | epoch: 63, train loss: 0.059824046763506805 1090 | epoch: 64, train loss: 0.06127727302637967 1091 | epoch: 65, train loss: 0.05966144563122229 1092 | epoch: 66, train loss: 0.060442881150679154 1093 | epoch: 67, train loss: 0.06116121126846834 1094 | epoch: 68, train loss: 0.07139951871200041 1095 | epoch: 69, train loss: 0.0663998970253901 1096 | epoch: 70, train loss: 0.06459799307313832 1097 | epoch: 71, train loss: 0.06426923620429906 1098 | epoch: 72, train loss: 0.06137757748365402 1099 | epoch: 73, train loss: 0.06281528283249248 1100 | epoch: 74, train loss: 0.06039875948970968 1101 | epoch: 75, train loss: 0.06050486253066496 1102 | epoch: 76, train loss: 0.06113762002099644 1103 | epoch: 77, train loss: 0.05993418260054155 1104 | epoch: 78, train loss: 0.05848369645801457 1105 | epoch: 79, train loss: 0.05995704199780117 1106 | epoch: 80, train loss: 0.059035287323323166 1107 | epoch: 81, train loss: 0.05856479501182383 1108 | epoch: 82, train loss: 0.05818232013420625 1109 | epoch: 83, train loss: 0.05771731280467727 1110 | epoch: 84, train loss: 0.05677052147009156 1111 | epoch: 85, train loss: 0.058797004209323364 1112 | epoch: 86, train loss: 0.05856989256360314 1113 | epoch: 87, train loss: 0.057263560593128204 1114 | epoch: 88, train loss: 0.057637367736209526 1115 | epoch: 89, train loss: 0.058894762261347336 1116 | epoch: 90, train loss: 0.0600922785022042 1117 | epoch: 91, train loss: 0.05820823833346367 1118 | epoch: 92, train loss: 0.05786852166056633 1119 | epoch: 93, train loss: 0.056362480941143905 1120 | epoch: 94, train loss: 0.05624002489176663 1121 | epoch: 95, train loss: 0.05687232247807763 1122 | epoch: 96, train loss: 0.0553965074094859 1123 | epoch: 97, train loss: 0.05558276650580493 1124 | epoch: 98, train loss: 0.05522116273641586 1125 | epoch: 99, train loss: 0.055639348924160004 1126 | epoch: 100, train loss: 0.055862909690900284 1127 | epoch: 101, train loss: 0.056899292902512985 1128 | epoch: 102, train loss: 0.05532975359396501 1129 | epoch: 103, train loss: 0.05472278628836979 1130 | epoch: 104, train loss: 0.053475059907544746 1131 | epoch: 105, train loss: 0.05487150563435121 1132 | epoch: 106, train loss: 0.05292389541864395 1133 | epoch: 107, train loss: 0.05355891009623354 1134 | epoch: 108, train loss: 0.0568819482895461 1135 | epoch: 109, train loss: 0.05595102601430633 1136 | epoch: 110, train loss: 0.054756989194588226 1137 | epoch: 111, train loss: 0.05340138653462583 1138 | epoch: 112, train loss: 0.05231737040660598 1139 | epoch: 
113, train loss: 0.05512223460457542 1140 | epoch: 114, train loss: 0.054171538149768654 1141 | epoch: 115, train loss: 0.05426797372373668 1142 | epoch: 116, train loss: 0.0550920120017095 1143 | epoch: 117, train loss: 0.0543502765623006 1144 | epoch: 118, train loss: 0.05473875863985582 1145 | epoch: 119, train loss: 0.052973354404622856 1146 | epoch: 120, train loss: 0.05479529398408803 1147 | epoch: 121, train loss: 0.05794807963750579 1148 | epoch: 122, train loss: 0.05754873529076576 1149 | epoch: 123, train loss: 0.05329101634296504 1150 | epoch: 124, train loss: 0.05547985874793746 1151 | epoch: 125, train loss: 0.05271359702402895 1152 | epoch: 126, train loss: 0.05253685706041076 1153 | epoch: 127, train loss: 0.05437989431348714 1154 | epoch: 128, train loss: 0.05193991755897349 1155 | epoch: 129, train loss: 0.050823495130647316 1156 | epoch: 130, train loss: 0.052761507643894714 1157 | epoch: 131, train loss: 0.0538366362452507 1158 | epoch: 132, train loss: 0.05371182682839307 1159 | epoch: 133, train loss: 0.05196439440954815 1160 | epoch: 134, train loss: 0.051621024242856285 1161 | epoch: 135, train loss: 0.04959283565933054 1162 | epoch: 136, train loss: 0.05052345178344033 1163 | epoch: 137, train loss: 0.05130480568517338 1164 | epoch: 138, train loss: 0.051159795712340965 1165 | epoch: 139, train loss: 0.05082153935324062 1166 | epoch: 140, train loss: 0.049924470484256744 1167 | epoch: 141, train loss: 0.04866128618066961 1168 | epoch: 142, train loss: 0.048158438368277115 1169 | epoch: 143, train loss: 0.0495351105928421 1170 | epoch: 144, train loss: 0.05142061547799544 1171 | epoch: 145, train loss: 0.0492656755853783 1172 | epoch: 146, train loss: 0.049103630875999275 1173 | epoch: 147, train loss: 0.049287565391172065 1174 | epoch: 148, train loss: 0.04774347421797839 1175 | epoch: 149, train loss: 0.048044307326728646 1176 | epoch: 1, train loss: 0.3576217648528871 1177 | epoch: 2, train loss: 0.19694486686161586 1178 | epoch: 3, train loss: 0.13485854793162572 1179 | epoch: 4, train loss: 0.12479584522190548 1180 | epoch: 5, train loss: 0.10090368950650805 1181 | epoch: 6, train loss: 0.08966896150793348 1182 | epoch: 7, train loss: 0.08704187330745515 1183 | epoch: 8, train loss: 0.08639521009865261 1184 | epoch: 9, train loss: 0.09292344287747428 1185 | epoch: 10, train loss: 0.08972003169002987 1186 | epoch: 11, train loss: 0.08028945025233995 1187 | epoch: 12, train loss: 0.07436548208906538 1188 | epoch: 13, train loss: 0.07301163620182446 1189 | epoch: 14, train loss: 0.07366513638269334 1190 | epoch: 15, train loss: 0.07142292087276776 1191 | epoch: 16, train loss: 0.06927433130996567 1192 | epoch: 17, train loss: 0.07202341496234849 1193 | epoch: 18, train loss: 0.06705349932114284 1194 | epoch: 19, train loss: 0.06770478863091696 1195 | epoch: 20, train loss: 0.06760141271210852 1196 | epoch: 21, train loss: 0.06934185538973127 1197 | epoch: 22, train loss: 0.06577856874182111 1198 | epoch: 23, train loss: 0.0676331716988768 1199 | epoch: 24, train loss: 0.06859728339172545 1200 | epoch: 25, train loss: 0.06593126111796924 1201 | epoch: 26, train loss: 0.0639015260551657 1202 | epoch: 27, train loss: 0.06996606308079902 1203 | epoch: 28, train loss: 0.0672030280388537 1204 | epoch: 29, train loss: 0.06830563130123275 1205 | epoch: 30, train loss: 0.06615219141046207 1206 | epoch: 31, train loss: 0.06344638019800186 1207 | epoch: 32, train loss: 0.06440130603455362 1208 | epoch: 33, train loss: 0.06381899632868313 1209 | epoch: 34, train loss: 
0.06411203369498253 1210 | epoch: 35, train loss: 0.06355187988706998 1211 | epoch: 36, train loss: 0.06872960073607308 1212 | epoch: 37, train loss: 0.0704853351981867 1213 | epoch: 38, train loss: 0.06556968976344381 1214 | epoch: 39, train loss: 0.06329683782089324 1215 | epoch: 40, train loss: 0.06296881110895247 1216 | epoch: 41, train loss: 0.06375821129906745 1217 | epoch: 42, train loss: 0.0616921613968554 1218 | epoch: 43, train loss: 0.0652287472926435 1219 | epoch: 44, train loss: 0.06203204606260572 1220 | epoch: 45, train loss: 0.06048317200371197 1221 | epoch: 46, train loss: 0.06354295107580367 1222 | epoch: 47, train loss: 0.06231904952299027 1223 | epoch: 48, train loss: 0.061328494655234475 1224 | epoch: 49, train loss: 0.06035503816036951 1225 | epoch: 50, train loss: 0.059758782209385006 1226 | epoch: 51, train loss: 0.059534592287881036 1227 | epoch: 52, train loss: 0.060229163084711344 1228 | epoch: 53, train loss: 0.06157733624180158 1229 | epoch: 54, train loss: 0.062134930243094764 1230 | epoch: 55, train loss: 0.06097933064614024 1231 | epoch: 56, train loss: 0.05978413298726082 1232 | epoch: 57, train loss: 0.059182744118429366 1233 | epoch: 58, train loss: 0.06105768077430271 1234 | epoch: 59, train loss: 0.0612194568273567 1235 | epoch: 60, train loss: 0.05993558732526643 1236 | epoch: 61, train loss: 0.059573265945627576 1237 | epoch: 62, train loss: 0.05930838077550843 1238 | epoch: 63, train loss: 0.058565804113944374 1239 | epoch: 64, train loss: 0.058464641194968 1240 | epoch: 65, train loss: 0.05797501440559115 1241 | epoch: 66, train loss: 0.058110394825538 1242 | epoch: 67, train loss: 0.060244703044493995 1243 | epoch: 68, train loss: 0.06211343691462562 1244 | epoch: 69, train loss: 0.061661472987561 1245 | epoch: 70, train loss: 0.05923799549539884 1246 | epoch: 71, train loss: 0.06195896047921408 1247 | epoch: 72, train loss: 0.05917967572098687 1248 | epoch: 73, train loss: 0.0564689047279812 1249 | epoch: 74, train loss: 0.0588022679800079 1250 | epoch: 75, train loss: 0.05958632787778264 1251 | epoch: 76, train loss: 0.05752028773228327 1252 | epoch: 77, train loss: 0.058797381818294525 1253 | epoch: 78, train loss: 0.0573489364413988 1254 | epoch: 79, train loss: 0.05955373566775095 1255 | epoch: 80, train loss: 0.05654046879637809 1256 | epoch: 81, train loss: 0.056560316788298745 1257 | epoch: 82, train loss: 0.057611440264043356 1258 | epoch: 83, train loss: 0.05950614171368735 1259 | epoch: 84, train loss: 0.06317316793969699 1260 | epoch: 85, train loss: 0.057698712462470644 1261 | epoch: 86, train loss: 0.05809514969587326 1262 | epoch: 87, train loss: 0.05636211874939147 1263 | epoch: 88, train loss: 0.05562863516665641 1264 | epoch: 89, train loss: 0.05683971551202592 1265 | epoch: 90, train loss: 0.057308328648408256 1266 | epoch: 91, train loss: 0.056705598142885026 1267 | epoch: 92, train loss: 0.05612527508111227 1268 | epoch: 93, train loss: 0.05683749825471923 1269 | epoch: 94, train loss: 0.056921850180342085 1270 | epoch: 95, train loss: 0.055374675386008765 1271 | epoch: 96, train loss: 0.055180130792515616 1272 | epoch: 97, train loss: 0.05432138946794328 1273 | epoch: 98, train loss: 0.05593647240173249 1274 | epoch: 99, train loss: 0.0558178573846817 1275 | epoch: 100, train loss: 0.05562879659590267 1276 | epoch: 101, train loss: 0.05494329723573867 1277 | epoch: 102, train loss: 0.05623596835704077 1278 | epoch: 103, train loss: 0.05572642980232125 1279 | epoch: 104, train loss: 0.06033795468863987 1280 | epoch: 105, train 
loss: 0.05680939768041883 1281 | epoch: 106, train loss: 0.05399674123951367 1282 | epoch: 107, train loss: 0.05391042768245652 1283 | epoch: 108, train loss: 0.05425775139814332 1284 | epoch: 109, train loss: 0.05482104154569762 1285 | epoch: 110, train loss: 0.05507248356228783 1286 | epoch: 111, train loss: 0.06120334867210615 1287 | epoch: 112, train loss: 0.05846887862398511 1288 | epoch: 113, train loss: 0.056076522739160625 1289 | epoch: 114, train loss: 0.05447333465729441 1290 | epoch: 115, train loss: 0.05512903808128266 1291 | epoch: 116, train loss: 0.054862640088512785 1292 | epoch: 117, train loss: 0.057510788419416974 1293 | epoch: 118, train loss: 0.05518257777605738 1294 | epoch: 119, train loss: 0.05404771846674737 1295 | epoch: 120, train loss: 0.054382767705690296 1296 | epoch: 121, train loss: 0.05666029879025051 1297 | epoch: 122, train loss: 0.05466393611970402 1298 | epoch: 123, train loss: 0.05338888029967036 1299 | epoch: 124, train loss: 0.05507822778253328 1300 | epoch: 125, train loss: 0.05351167509243602 1301 | epoch: 126, train loss: 0.05389995340790067 1302 | epoch: 127, train loss: 0.054812408096733545 1303 | epoch: 128, train loss: 0.0540713451447941 1304 | epoch: 129, train loss: 0.052734143854606716 1305 | epoch: 130, train loss: 0.05313122645020485 1306 | epoch: 131, train loss: 0.05371000546784628 1307 | epoch: 132, train loss: 0.05379836527364595 1308 | epoch: 133, train loss: 0.05357941665819713 1309 | epoch: 134, train loss: 0.05299951189330646 1310 | epoch: 135, train loss: 0.052875928580760956 1311 | epoch: 136, train loss: 0.05317306216983568 1312 | epoch: 137, train loss: 0.05342496204234305 1313 | epoch: 138, train loss: 0.05324836713927133 1314 | epoch: 139, train loss: 0.050817567677724926 1315 | epoch: 140, train loss: 0.05199626709024111 1316 | epoch: 141, train loss: 0.051051844975778034 1317 | epoch: 142, train loss: 0.05155260630306743 1318 | epoch: 143, train loss: 0.05227173155262357 1319 | epoch: 144, train loss: 0.052376916365964074 1320 | epoch: 145, train loss: 0.05149342101954278 1321 | epoch: 146, train loss: 0.054154952544541585 1322 | epoch: 147, train loss: 0.05423387246472495 1323 | epoch: 148, train loss: 0.05638335183972404 1324 | epoch: 149, train loss: 0.056777234233561014 1325 | epoch: 150, train loss: 0.055373086993183405 1326 | epoch: 151, train loss: 0.05456227614056496 1327 | epoch: 152, train loss: 0.05452814999790419 1328 | epoch: 153, train loss: 0.055792558051290964 1329 | epoch: 154, train loss: 0.05362413060807046 1330 | epoch: 155, train loss: 0.052476268439065846 1331 | epoch: 156, train loss: 0.05453459137961978 1332 | epoch: 157, train loss: 0.053781277721836454 1333 | epoch: 158, train loss: 0.05156948098114559 1334 | epoch: 159, train loss: 0.052273347264244444 1335 | epoch: 160, train loss: 0.05047128296324185 1336 | epoch: 161, train loss: 0.06642305318798337 1337 | epoch: 162, train loss: 0.06559817918709346 1338 | epoch: 163, train loss: 0.05889031539360682 1339 | epoch: 164, train loss: 0.056687554788021816 1340 | epoch: 165, train loss: 0.0554568214075906 1341 | epoch: 166, train loss: 0.05338937345714796 1342 | epoch: 167, train loss: 0.05210849660493079 1343 | epoch: 168, train loss: 0.051238762658266795 1344 | epoch: 169, train loss: 0.051222966895216986 1345 | epoch: 170, train loss: 0.051795351185968945 1346 | epoch: 171, train loss: 0.05225818639709836 1347 | epoch: 172, train loss: 0.054356725265582405 1348 | epoch: 173, train loss: 0.05296634971385911 1349 | epoch: 174, train loss: 
0.052210201642342975 1350 | epoch: 175, train loss: 0.05274832585737819 1351 | epoch: 176, train loss: 0.050579723148118885 1352 | epoch: 177, train loss: 0.051733998847859244 1353 | epoch: 178, train loss: 0.05056388942258699 1354 | epoch: 179, train loss: 0.05122878526647886 1355 | epoch: 180, train loss: 0.049827019373575844 1356 | epoch: 181, train loss: 0.050073687164556416 1357 | epoch: 182, train loss: 0.05053699389100075 1358 | epoch: 183, train loss: 0.05099563939230783 1359 | epoch: 184, train loss: 0.048637845331714266 1360 | epoch: 185, train loss: 0.05016545507879484 1361 | epoch: 186, train loss: 0.052367208436841055 1362 | epoch: 187, train loss: 0.05171095491165206 1363 | epoch: 188, train loss: 0.051517813688232786 1364 | epoch: 189, train loss: 0.05167661323433831 1365 | epoch: 190, train loss: 0.05172619099418322 1366 | epoch: 191, train loss: 0.04892781199443908 1367 | epoch: 192, train loss: 0.04762096596615655 1368 | epoch: 193, train loss: 0.04878771855008034 1369 | epoch: 194, train loss: 0.04986766956391789 1370 | epoch: 195, train loss: 0.05437280290893146 1371 | epoch: 196, train loss: 0.055148212505238395 1372 | epoch: 197, train loss: 0.05143432940045992 1373 | epoch: 198, train loss: 0.04949059337377548 1374 | epoch: 199, train loss: 0.04817831037299974 1375 | epoch: 1, train loss: 1376 | 0.2938 1377 | [torch.cuda.FloatTensor of size () (GPU 0)] 1378 | 1379 | epoch: 1, train loss: 1380 | 0.2868 1381 | [torch.cuda.FloatTensor of size () (GPU 0)] 1382 | 1383 | epoch: 1, train loss: 1384 | 0.2903 1385 | [torch.cuda.FloatTensor of size () (GPU 0)] 1386 | 1387 | epoch: 1, train loss: 1388 | 0.4778 1389 | [torch.cuda.FloatTensor of size () (GPU 0)] 1390 | 1391 | epoch: 1, train loss: 1392 | 0.3351 1393 | [torch.cuda.FloatTensor of size () (GPU 0)] 1394 | 1395 | epoch: 1, train loss: 1396 | 0.3206 1397 | [torch.cuda.FloatTensor of size () (GPU 0)] 1398 | 1399 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | UNet 3 | The main UNet model implementation 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | # Utility Functions 12 | '''With a 3x3 kernel, padding=1 keeps the output feature map the same spatial size as the input.''' 13 | def conv_bn_relu(in_channels, out_channels, kernel_size=3, stride=1, padding=1): 14 | return nn.Sequential( 15 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding), 16 | nn.BatchNorm2d(out_channels), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding), 19 | nn.BatchNorm2d(out_channels), 20 | nn.ReLU(inplace=True), 21 | ) 22 | 23 | def down_pooling(): 24 | return nn.MaxPool2d(2) 25 | 26 | def up_pooling(in_channels, out_channels, kernel_size=2, stride=2): 27 | return nn.Sequential( 28 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride), 29 | nn.BatchNorm2d(out_channels), 30 | nn.ReLU(inplace=True) 31 | ) 32 | 33 | # UNet class 34 | 35 | class UNet(nn.Module): 36 | def __init__(self, input_channels, nclasses): 37 | super().__init__() 38 | # go down 39 | self.conv1 = conv_bn_relu(input_channels,64) 40 | self.conv2 = conv_bn_relu(64, 128) 41 | self.conv3 = conv_bn_relu(128, 256) 42 | self.conv4 = conv_bn_relu(256, 512) 43 | self.conv5 = conv_bn_relu(512, 1024) 44 | self.down_pooling = nn.MaxPool2d(2) 45 | 46 | # go up 47 |
self.up_pool6 = up_pooling(1024, 512) 48 | self.conv6 = conv_bn_relu(1024, 512) 49 | self.up_pool7 = up_pooling(512, 256) 50 | self.conv7 = conv_bn_relu(512, 256) 51 | self.up_pool8 = up_pooling(256, 128) 52 | self.conv8 = conv_bn_relu(256, 128) 53 | self.up_pool9 = up_pooling(128, 64) 54 | self.conv9 = conv_bn_relu(128, 64) 55 | 56 | self.conv10 = nn.Conv2d(64, nclasses, 1) 57 | 58 | 59 | # Kaiming (He) weight initialization for all conv layers 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_out') 63 | if m.bias is not None: 64 | m.bias.data.zero_() 65 | 66 | 67 | def forward(self, x): 68 | # normalize input data 69 | x = x/255. 70 | # go down 71 | x1 = self.conv1(x) 72 | p1 = self.down_pooling(x1) 73 | x2 = self.conv2(p1) 74 | p2 = self.down_pooling(x2) 75 | x3 = self.conv3(p2) 76 | p3 = self.down_pooling(x3) 77 | x4 = self.conv4(p3) 78 | p4 = self.down_pooling(x4) 79 | x5 = self.conv5(p4) 80 | 81 | # go up 82 | p6 = self.up_pool6(x5) 83 | x6 = torch.cat([p6, x4], dim=1) 84 | x6 = self.conv6(x6) 85 | 86 | p7 = self.up_pool7(x6) 87 | x7 = torch.cat([p7, x3], dim=1) 88 | x7 = self.conv7(x7) 89 | 90 | p8 = self.up_pool8(x7) 91 | x8 = torch.cat([p8, x2], dim=1) 92 | x8 = self.conv8(x8) 93 | 94 | p9 = self.up_pool9(x8) 95 | x9 = torch.cat([p9, x1], dim=1) 96 | x9 = self.conv9(x9) 97 | 98 | 99 | output = self.conv10(x9) 100 | output = F.sigmoid(output) 101 | 102 | return output 103 | 104 | class UNet2(nn.Module): 105 | def __init__(self, input_channels, nclasses): 106 | super().__init__() 107 | # go down 108 | self.conv1 = conv_bn_relu(input_channels,16) 109 | self.conv2 = conv_bn_relu(16, 32) 110 | self.conv3 = conv_bn_relu(32, 64) 111 | self.conv4 = conv_bn_relu(64, 128) 112 | self.conv5 = conv_bn_relu(128, 256) 113 | self.down_pooling = nn.MaxPool2d(2) 114 | 115 | # go up 116 | self.up_pool6 = up_pooling(256, 128) 117 | self.conv6 = conv_bn_relu(256, 128) 118 | self.up_pool7 = up_pooling(128, 64) 119 | self.conv7 = conv_bn_relu(128, 64) 120 | self.up_pool8 = up_pooling(64, 32) 121 | self.conv8 = conv_bn_relu(64, 32) 122 | self.up_pool9 = up_pooling(32, 16) 123 | self.conv9 = conv_bn_relu(32, 16) 124 | 125 | self.conv10 = nn.Conv2d(16, nclasses, 1) 126 | 127 | 128 | # Kaiming (He) weight initialization for all conv layers 129 | for m in self.modules(): 130 | if isinstance(m, nn.Conv2d): 131 | nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_out') 132 | if m.bias is not None: 133 | m.bias.data.zero_() 134 | 135 | 136 | def forward(self, x): 137 | # normalize input data 138 | x = x/255.
139 | # go down 140 | x1 = self.conv1(x) 141 | p1 = self.down_pooling(x1) 142 | x2 = self.conv2(p1) 143 | p2 = self.down_pooling(x2) 144 | x3 = self.conv3(p2) 145 | p3 = self.down_pooling(x3) 146 | x4 = self.conv4(p3) 147 | p4 = self.down_pooling(x4) 148 | x5 = self.conv5(p4) 149 | 150 | # go up 151 | p6 = self.up_pool6(x5) 152 | x6 = torch.cat([p6, x4], dim=1) 153 | x6 = self.conv6(x6) 154 | 155 | p7 = self.up_pool7(x6) 156 | x7 = torch.cat([p7, x3], dim=1) 157 | x7 = self.conv7(x7) 158 | 159 | p8 = self.up_pool8(x7) 160 | x8 = torch.cat([p8, x2], dim=1) 161 | x8 = self.conv8(x8) 162 | 163 | p9 = self.up_pool9(x8) 164 | x9 = torch.cat([p9, x1], dim=1) 165 | x9 = self.conv9(x9) 166 | 167 | 168 | output = self.conv10(x9) 169 | output = F.sigmoid(output) 170 | 171 | return output 172 | 173 | -------------------------------------------------------------------------------- /plot_loss.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | 4 | data = pd.read_csv('logs.txt', sep=':', header=None) # the training log written by train.py 5 | 6 | loss = data[2] # the third ':'-separated field holds the loss value (train and validation rows are interleaved) 7 | 8 | plt.figure() 9 | plt.title('UNet training loss', fontsize=20) 10 | plt.xlabel('epoch', fontsize=15) 11 | plt.ylabel('loss', fontsize=15) 12 | plt.plot(loss, linewidth=2) 13 | plt.show() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | UNet 3 | Train Unet model 4 | """ 5 | import numpy as np 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.autograd import Variable 11 | from dataset import get_train_valid_loader, get_test_loader 12 | from model import UNet2 13 | from utils import Option, encode_and_save, compute_iou 14 | from skimage import io 15 | from skimage.transform import resize 16 | import matplotlib.pyplot as plt 17 | 18 | def train(model, train_loader, opt, criterion, epoch): 19 | model.train() 20 | num_batches = 0 21 | avg_loss = 0 22 | with open('logs.txt', 'a') as file: 23 | for batch_idx, sample_batched in enumerate(train_loader): 24 | data = sample_batched['image'] 25 | target = sample_batched['mask'] 26 | data, target = Variable(data.type(opt.dtype)), Variable(target.type(opt.dtype)) 27 | optimizer.zero_grad() # optimizer is created at module level in __main__ 28 | output = model(data) 29 | # output = (output > 0.5).type(opt.dtype) # thresholding here uses more GPU memory and leaves the loss unchanged 30 | loss = criterion(output, target) 31 | loss.backward() 32 | optimizer.step() 33 | avg_loss += loss.data[0] 34 | num_batches += 1 35 | avg_loss /= num_batches 36 | # avg_loss /= len(train_loader.dataset) 37 | print('epoch: ' + str(epoch) + ', train loss: ' + str(avg_loss)) 38 | file.write('epoch: ' + str(epoch) + ', train loss: ' + str(avg_loss) + '\n') 39 | 40 | def val(model, val_loader, opt, criterion, epoch): 41 | model.eval() 42 | num_batches = 0 43 | avg_loss = 0 44 | with open('logs.txt', 'a') as file: 45 | for batch_idx, sample_batched in enumerate(val_loader): 46 | data = sample_batched['image'] 47 | target = sample_batched['mask'] 48 | data, target = Variable(data.type(opt.dtype)), Variable(target.type(opt.dtype)) 49 | output = model.forward(data) 50 | # output = (output > 0.5).type(opt.dtype) # thresholding here uses more GPU memory and leaves the loss unchanged 51 | loss = criterion(output, target) 52 | avg_loss += loss.data[0] 53 | num_batches += 1 54 | avg_loss /= num_batches 55 | # avg_loss /= len(val_loader.dataset) 56 | 57 |
print('epoch: ' + str(epoch) + ', validation loss: ' + str(avg_loss)) 58 | file.write('epoch: ' + str(epoch) + ', validation loss: ' + str(avg_loss) + '\n') 59 | 60 | # train and validation 61 | def run(model, train_loader, val_loader, opt, criterion): 62 | for epoch in range(1, opt.epochs + 1): # train for opt.epochs epochs, numbered from 1 63 | train(model, train_loader, opt, criterion, epoch) 64 | val(model, val_loader, opt, criterion, epoch) 65 | 66 | # only train 67 | def run_train(model, train_loader, opt, criterion): 68 | for epoch in range(1, opt.epochs + 1): # train for opt.epochs epochs, numbered from 1 69 | train(model, train_loader, opt, criterion, epoch) 70 | 71 | # make prediction 72 | def run_test(model, test_loader, opt): 73 | """ 74 | predict the masks on the testing set 75 | :param model: trained model 76 | :param test_loader: testing set 77 | :param opt: configurations 78 | :return: 79 | - predictions: list, each element a numpy array of shape (Height, Width) 80 | - img_ids: list, each element an image id string 81 | """ 82 | model.eval() 83 | predictions = [] 84 | img_ids = [] 85 | for batch_idx, sample_batched in enumerate(test_loader): 86 | data, img_id, height, width = sample_batched['image'], sample_batched['img_id'], sample_batched['height'], sample_batched['width'] 87 | data = Variable(data.type(opt.dtype)) 88 | output = model.forward(data) 89 | # output = (output > 0.5) 90 | output = output.data.cpu().numpy() 91 | output = output.transpose((0, 2, 3, 1)) # transpose to (B,H,W,C) 92 | for i in range(0,output.shape[0]): 93 | pred_mask = np.squeeze(output[i]) 94 | id = img_id[i] 95 | h = height[i] 96 | w = width[i] 97 | # depending on the environment, h and w arrive either as plain ints 98 | # or as LongTensors, so convert tensors before resizing 99 | if not isinstance(h, int): 100 | h = h.cpu().numpy() 101 | w = w.cpu().numpy() 102 | pred_mask = resize(pred_mask, (h, w), mode='constant') 103 | pred_mask = (pred_mask > 0.5) 104 | predictions.append(pred_mask) 105 | img_ids.append(id) 106 | 107 | return predictions, img_ids 108 | 109 | if __name__ == '__main__': 110 | """Train Unet model""" 111 | opt = Option() 112 | model = UNet2(input_channels=3, nclasses=1) 113 | if opt.is_train: 114 | # split all data into training and validation sets (split=True) 115 | train_loader, val_loader = get_train_valid_loader(opt.root_dir, batch_size=opt.batch_size, 116 | split=True, shuffle=opt.shuffle, 117 | num_workers=opt.num_workers, 118 | val_ratio=0.1, pin_memory=opt.pin_memory) 119 | 120 | # load all data for training 121 | # train_loader = get_train_valid_loader(opt.root_dir, batch_size=opt.batch_size, 122 | # split=False, shuffle=opt.shuffle, 123 | # num_workers=opt.num_workers, 124 | # val_ratio=0.1, pin_memory=opt.pin_memory) 125 | if opt.n_gpu > 1: 126 | model = nn.DataParallel(model) 127 | if opt.is_cuda: 128 | model = model.cuda() 129 | optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay) 130 | criterion = nn.BCELoss().cuda() if opt.is_cuda else nn.BCELoss() # BCE matches the model's sigmoid output 131 | # start training 132 | run_train(model, train_loader, opt, criterion) 133 | # make prediction on validation set 134 | predictions, img_ids = run_test(model, val_loader, opt) 135 | # compute IOU between prediction and ground truth masks 136 | compute_iou(predictions, img_ids, val_loader) 137 | # save the trained model 138 | if opt.save_model: 139 | torch.save(model.state_dict(), os.path.join(opt.checkpoint_dir, 'model-01.pt')) 140 | else: 141 | # load testing data for making predictions 142 | test_loader = get_test_loader(opt.test_dir, batch_size=opt.batch_size, shuffle=opt.shuffle, 143 | num_workers=opt.num_workers, pin_memory=opt.pin_memory) 144 | # load the
model and run test 145 | model.load_state_dict(torch.load(os.path.join(opt.checkpoint_dir, 'model-01.pt'))) # note: checkpoints saved from a DataParallel model carry "module."-prefixed keys 146 | if opt.n_gpu > 1: 147 | model = nn.DataParallel(model) 148 | if opt.is_cuda: 149 | model = model.cuda() 150 | predictions, img_ids = run_test(model, test_loader, opt) 151 | # run length encoding and save as csv 152 | encode_and_save(predictions, img_ids) 153 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | UNet 3 | Common utility functions and classes 4 | """ 5 | 6 | import os 7 | import sys 8 | import numpy as np 9 | from tqdm import tqdm 10 | from skimage import io 11 | from PIL import Image 12 | 13 | import torch 14 | from skimage.morphology import label 15 | import pandas as pd 16 | import matplotlib.pylab as plt 17 | 18 | 19 | # Base Configuration class 20 | # Don't use this class directly. Instead, sub-class it and override 21 | 22 | class Config(): 23 | 24 | name = None 25 | 26 | img_width = 256 27 | img_height = 256 28 | 29 | img_channel = 3 30 | 31 | batch_size = 16 32 | 33 | learning_rate = 1e-3 34 | learning_momentum = 0.9 35 | weight_decay = 1e-4 36 | 37 | shuffle = False 38 | 39 | def __init__(self): 40 | self.IMAGE_SHAPE = np.array([ 41 | self.img_width, self.img_height, self.img_channel 42 | ]) 43 | 44 | def display(self): 45 | """Display Configuration values""" 46 | print("\nConfigurations:") 47 | for a in dir(self): 48 | if not a.startswith("__") and not callable(getattr(self, a)): 49 | print("{:30} {}".format(a, getattr(self, a))) 50 | print("\n") 51 | 52 | # Configurations 53 | 54 | class Option(Config): 55 | """Configuration for training on Kaggle Data Science Bowl 2018 56 | Derived from the base Config class and overrides specific values 57 | """ 58 | name = "DSB2018" 59 | 60 | # root dir of training and validation set 61 | root_dir = '/home/liming/Documents/dataset/dataScienceBowl2018/combined' 62 | 63 | # root dir of testing set 64 | test_dir = '/home/liming/Documents/dataset/dataScienceBowl2018/testing_data' 65 | 66 | # save segmenting results (prediction masks) to this folder 67 | results_dir = '/home/liming/Documents/dataset/dataScienceBowl2018/results' 68 | 69 | num_workers = 1 # number of threads for data loading 70 | shuffle = True # shuffle the data set 71 | batch_size = 16 # fits a GTX 1060 with 3GB of memory 72 | epochs = 2 # number of epochs to train 73 | is_train = True # True for training, False for making prediction 74 | save_model = False # True for saving the model, False for not saving the model 75 | 76 | n_gpu = 1 # number of GPUs 77 | 78 | learning_rate = 1e-3 # learning rate 79 | weight_decay = 1e-4 # weight decay 80 | 81 | pin_memory = True # use pinned (page-locked) memory.
when using CUDA, set to True 82 | 83 | is_cuda = torch.cuda.is_available() # True --> GPU 84 | num_gpus = torch.cuda.device_count() # number of GPUs 85 | checkpoint_dir = "./checkpoint" # dir to save checkpoints 86 | dtype = torch.cuda.FloatTensor if is_cuda else torch.Tensor # data type 87 | 88 | """ 89 | Dataset organization: 90 | Read images and masks, combine separated masks into one 91 | Write images and combined masks into specific folder 92 | """ 93 | class Utils(object): 94 | """ 95 | Prepares the DSB2018 data: combines per-nucleus masks and copies images and masks to the destination folders 96 | """ 97 | def __init__(self, stage1_train_src, stage1_train_dest, stage1_test_src, stage1_test_dest): 98 | self.opt = Option 99 | self.stage1_train_src = stage1_train_src 100 | self.stage1_train_dest = stage1_train_dest 101 | self.stage1_test_src = stage1_test_src 102 | self.stage1_test_dest = stage1_test_dest 103 | 104 | # Combine all separated masks into one mask 105 | def assemble_masks(self, path): 106 | # mask = np.zeros((self.config.IMG_HEIGHT, self.config.IMG_WIDTH), dtype=np.uint8) 107 | mask = None 108 | for i, mask_file in enumerate(next(os.walk(os.path.join(path, 'masks')))[2]): 109 | mask_ = Image.open(os.path.join(path, 'masks', mask_file)).convert("RGB") 110 | # mask_ = mask_.resize((self.config.IMG_HEIGHT, self.config.IMG_WIDTH)) 111 | mask_ = np.asarray(mask_) 112 | if i == 0: 113 | mask = mask_ 114 | continue 115 | mask = mask | mask_ 116 | # mask = np.expand_dims(mask, axis=-1) 117 | return mask 118 | 119 | # read all training data and save it to the destination folder 120 | def prepare_training_data(self): 121 | # get imageId 122 | train_ids = next(os.walk(self.stage1_train_src))[1] 123 | 124 | # read training data 125 | X_train = [] 126 | Y_train = [] 127 | print('reading training data starts...') 128 | sys.stdout.flush() 129 | for n, id_ in tqdm(enumerate(train_ids)): 130 | path = os.path.join(self.stage1_train_src, id_) 131 | dest = os.path.join(self.stage1_train_dest, id_) # assumes this output folder already exists (cf. prepare_testing_data) 132 | img = Image.open(os.path.join(path, 'images', id_ + '.png')).convert("RGB") 133 | mask = self.assemble_masks(path) 134 | img.save(os.path.join(dest, 'image.png')) 135 | Image.fromarray(mask).save(os.path.join(dest, 'mask.png')) 136 | 137 | print('reading training data done...') 138 | 139 | # read testing data and save it to the destination folder 140 | def prepare_testing_data(self): 141 | # get imageId 142 | test_ids = next(os.walk(self.stage1_test_src))[1] 143 | # read testing data 144 | print('reading testing data starts...') 145 | sys.stdout.flush() 146 | for n, id_ in tqdm(enumerate(test_ids)): 147 | path = os.path.join(self.stage1_test_src, id_, 'images', id_+'.png') 148 | dest = os.path.join(self.stage1_test_dest, id_) 149 | if not os.path.exists(dest): 150 | os.mkdir(dest) 151 | img = Image.open(path).convert("RGB") 152 | img.save(os.path.join(dest, 'image.png')) 153 | 154 | print('reading testing data done...') 155 | 156 | 157 | def compute_iou(predictions, img_ids, val_loader): 158 | """ 159 | compute the IoU between predicted and ground-truth combined masks; note this is a plain per-image IoU, not Kaggle's official evaluation metric 160 | :return: IOU, between 0 and 1 161 | """ 162 | ious = [] 163 | for i in range(0, len(img_ids)): 164 | pred = predictions[i] 165 | img_id = img_ids[i] 166 | mask_path = os.path.join(Option.root_dir, img_id, 'mask.png') 167 | mask = np.asarray(Image.open(mask_path).convert('L'), dtype=np.bool) 168 | union = np.sum(np.logical_or(mask, pred)) 169 | intersection = np.sum(np.logical_and(mask, pred)) 170 | iou = intersection/union 171 | ious.append(iou) 172 | df =
pd.DataFrame({'img_id':img_ids,'iou':ious}) 173 | df.to_csv('IOU.csv', index=False) 174 | 175 | 176 | 177 | # Run-length encoding stolen from https://www.kaggle.com/rakhlin/fast-run-length-encoding-python 178 | def rle_encoding(x): 179 | dots = np.where(x.T.flatten() == 1)[0] 180 | run_lengths = [] 181 | prev = -2 182 | for b in dots: 183 | if (b>prev+1): run_lengths.extend((b + 1, 0)) 184 | run_lengths[-1] += 1 185 | prev = b 186 | return run_lengths 187 | 188 | def prob_to_rles(x, cutoff=0.5): 189 | lab_img = label(x > cutoff) 190 | for i in range(1, lab_img.max() + 1): 191 | yield rle_encoding(lab_img == i) 192 | 193 | def encode_and_save(preds_test_upsampled, test_ids): 194 | """ 195 | Run-length encode the prediction masks and save them to a csv file for submission 196 | :param preds_test_upsampled: list, each element a numpy array of shape (Height, Width) 197 | :param test_ids: list, each element an image id 198 | :return: 199 | saves the encodings to a csv file 200 | """ 201 | # save as imgs 202 | for i in range(0, len(test_ids)): 203 | path = os.path.join(Option.results_dir, test_ids[i]) 204 | if not os.path.exists(path): 205 | os.mkdir(path) 206 | # Image.fromarray(preds_test_upsampled[i]).save(os.path.join(path,'prediction.png')) 207 | plt.imsave(os.path.join(path, 'prediction.png'),preds_test_upsampled[i], cmap='gray') 208 | # save as encoding 209 | new_test_ids = [] 210 | rles = [] 211 | for n, id_ in enumerate(test_ids): 212 | rle = list(prob_to_rles(preds_test_upsampled[n])) 213 | rles.extend(rle) 214 | new_test_ids.extend([id_] * len(rle)) 215 | 216 | sub = pd.DataFrame() 217 | sub['ImageId'] = new_test_ids 218 | sub['EncodedPixels'] = pd.Series(rles).apply(lambda x: ' '.join(str(y) for y in x)) 219 | sub.to_csv('sub-dsbowl2018.csv', index=False) 220 | 221 | if __name__ == '__main__': 222 | """ Prepare training data and testing data: 223 | read the data, combine the per-nucleus masks, and save to the destination paths 224 | """ 225 | stage1_train_src = '/home/liming/Documents/dataset/dataScienceBowl2018/stage1_train' 226 | stage1_train_dest = '/home/liming/Documents/dataset/dataScienceBowl2018/combined' 227 | stage1_test_src = '/home/liming/Documents/dataset/dataScienceBowl2018/stage1_test' 228 | stage1_test_dest = '/home/liming/Documents/dataset/dataScienceBowl2018/testing_data' 229 | 230 | util = Utils(stage1_train_src, stage1_train_dest, stage1_test_src, stage1_test_dest) 231 | util.prepare_training_data() 232 | util.prepare_testing_data() --------------------------------------------------------------------------------
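
A quick way to sanity-check the UNet2 architecture in model.py is to push a dummy batch through it and confirm the predicted mask comes out at the input resolution. The sketch below is illustrative and not part of the repository; it assumes a recent PyTorch where plain tensors can be fed to modules. Input sides should be divisible by 16, since the four max-pooling steps each halve the spatial size and the four transposed convolutions double it back.

import torch
from model import UNet2

model = UNet2(input_channels=3, nclasses=1)
model.eval()
x = torch.randn(1, 3, 256, 256)  # dummy batch of one "image"; forward() divides by 255 internally
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 1, 256, 256]), values in (0, 1) from the sigmoid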
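
plot_loss.py plots every loss field in file order, but logs.txt interleaves train and validation rows (and accumulates several runs back to back), so the two curves end up mixed. A variant that separates them is sketched below; it is illustrative, not part of the repository, and assumes the 'epoch: N, <kind> loss: <value>' format seen in the log dump above, with the x-axis being the log row index rather than a true epoch number.

import pandas as pd
import matplotlib.pyplot as plt

# splitting on ':' yields three fields: 'epoch', ' N, <kind> loss', ' <value>'
data = pd.read_csv('logs.txt', sep=':', header=None, names=['prefix', 'kind', 'loss'])
train_loss = data[data['kind'].str.contains('train', na=False)]['loss'].reset_index(drop=True)
val_loss = data[data['kind'].str.contains('validation', na=False)]['loss'].reset_index(drop=True)

plt.figure()
plt.plot(train_loss, linewidth=2, label='train')
plt.plot(val_loss, linewidth=2, label='validation')
plt.xlabel('log row')
plt.ylabel('loss')
plt.legend()
plt.show()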
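
compute_iou in utils.py scores each image with plain intersection-over-union between the thresholded prediction and the combined ground-truth mask: IoU = |A ∩ B| / |A ∪ B|. A toy example of the same arithmetic, with made-up masks:

import numpy as np

pred = np.array([[1, 1, 0, 0]], dtype=bool)
gt = np.array([[0, 1, 1, 0]], dtype=bool)
intersection = np.sum(np.logical_and(gt, pred))  # 1 pixel in both masks
union = np.sum(np.logical_or(gt, pred))          # 3 pixels in either mask
print(intersection / union)                      # 0.333...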
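
rle_encoding in utils.py flattens the mask column-major (x.T.flatten()) and emits 1-indexed (start, run length) pairs, the submission format Kaggle expects for DSB2018. A worked example, assuming utils.py is importable from the current directory:

import numpy as np
from utils import rle_encoding

mask = np.array([[0, 1],
                 [1, 1]])
# column-major flattening gives [0, 1, 1, 1]: the run of ones starts at
# 1-indexed position 2 and is 3 pixels long
print(rle_encoding(mask))  # [2, 3]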