├── .github
│   └── workflows
│       └── python-app.yml
├── .gitignore
├── README.md
├── check_files_statistics.py
├── configs
│   └── default.yaml
├── core
│   ├── __init__.py
│   ├── modules.py
│   ├── res_unet.py
│   ├── res_unet_plus.py
│   └── unet.py
├── dataset
│   └── dataloader.py
├── preprocess.py
├── requirements.txt
├── tests
│   ├── __init__.py
│   └── test_res_unet.py
├── train.py
└── utils
    ├── __init__.py
    ├── augmentation.py
    ├── hparams.py
    ├── logger.py
    └── metrics.py

/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python application
5 |
6 | on:
7 |   push:
8 |     branches: [ master ]
9 |   pull_request:
10 |     branches: [ master ]
11 |
12 | jobs:
13 |   build:
14 |
15 |     runs-on: ubuntu-latest
16 |
17 |     steps:
18 |     - uses: actions/checkout@v2
19 |     - name: Set up Python 3.6
20 |       uses: actions/setup-python@v2
21 |       with:
22 |         python-version: 3.6
23 |     - name: Install dependencies
24 |       run: |
25 |         python -m pip install --upgrade pip
26 |         pip install flake8 pytest
27 |         pip install autopep8 || true
28 |         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
29 |     - name: Lint with flake8
30 |       run: |
31 |         # stop the build if there are Python syntax errors or undefined names
32 |         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
33 |         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
34 |         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
35 |     - name: Test with pytest
36 |       run: |
37 |         pytest
38 |
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /data
2 | .DS_Store
3 | .idea/
4 | .ipynb_checkpoints
5 | **/.ipynb_checkpoints/*
6 | *__pycache__*
7 | *.zip
8 | /logs
9 | /.idea
10 | /*.ini
11 | .idea
12 | idea/*
13 | /data
14 | /output
15 | /logs
16 | /__pycache__
17 | /core/__pycache__
18 | /core/duration_modeling/__pycache__
19 | /core/energy_predictor/__pycache__
20 | /core/pitch_predictor/__pycache__
21 | /dataset/__pycache__
22 | /dataset/texts/__pycache__
23 | /utils/__pycache__
24 | /modules/__pycache__
25 | /checkpoints
26 | /trace_loss.txt
27 | /unused_code.txt
28 | /test.py
29 | /rest_tts.py
30 | /preprocess.py
31 | /trace_loss_nvidia.txt
32 | /conf
33 | /dataset/text/__pycache__
34 | /road_segmentation_ideal
35 | /hparams.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep ResUnet and ResUnet ++ (Not Maintained)
2 | Unofficial PyTorch implementation of the following papers:
3 | * [Deep ResUnet](https://arxiv.org/pdf/1711.10684.pdf)
4 | * [ResUnet ++](https://arxiv.org/pdf/1911.07067.pdf)
5 |
6 | ## Note
7 | * This repo was written for experimentation (fun) purposes and is heavily hard-coded, so avoid using it as-is in a production environment.
8 | * I only wrote the ResUnet and ResUnet++ models; the Unet implementation is borrowed from this [repo](https://github.com/jeffwen/road_building_extraction).
9 | * Use your own pre-processing and dataloader; the dataloader and pre-processing in this repo were written for a specific use case (see the sketch below).
10 | * This repo has only been tested on the [Massachusetts Roads Dataset](https://www.cs.toronto.edu/~vmnih/data/).
11 |
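For readers who follow the note above and write their own dataloader, the following is a minimal, hypothetical sketch of a drop-in `Dataset`. Only the returned dict keys (`sat_img`, `map_img`) and tensor shapes are dictated by `train.py`; the class name, constructor arguments, and file-pairing scheme below are illustrative assumptions, not part of this repo.

```python
# Hypothetical drop-in replacement for dataset.dataloader.ImageDataset.
# train.py only requires each sample to be a dict with:
#   "sat_img": float tensor (3, H, W) in [0, 1]  -> network input
#   "map_img": float tensor (1, H, W) in [0, 1]  -> binary road mask
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class MyRoadDataset(Dataset):
    def __init__(self, image_paths, mask_paths, size=224):
        assert len(image_paths) == len(mask_paths)
        self.image_paths = image_paths          # paths to RGB tiles
        self.mask_paths = mask_paths            # paths to matching binary masks
        self.size = size
        self.to_tensor = transforms.ToTensor()  # HWC uint8 -> CHW float in [0, 1]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB").resize((self.size, self.size))
        mask = Image.open(self.mask_paths[idx]).convert("L").resize((self.size, self.size))
        return {"sat_img": self.to_tensor(image), "map_img": self.to_tensor(mask)}
```

An instance of such a class can be passed straight to `torch.utils.data.DataLoader` in place of `dataset.dataloader.ImageDataset` in `train.py`.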
12 | ## Pre-processing
13 | * This pre-processing is for a specific use case and assumes a strict directory structure.
14 | ````buildoutcfg
15 | python preprocess.py --config "configs/default.yaml" --train training_files_dir --valid validation_files_dir
16 | ````
17 | * The training and validation directories passed as `args` above should each contain two folders: `input` for input images and `output` for target masks (i.e. `training_files_dir/input` and `training_files_dir/output`). All images are of a fixed square size (in this case `1500 * 1500` pixels).
18 | * Pre-processing crops each input and target image into several fixed-size (in this case `224 * 224`) patches, which are saved into `input_crop` and `mask_crop` under the training and validation dump directories given in the config file.
19 | * You can change the training and validation dump directories in the config file, i.e. `configs/default.yaml`.
20 | ## Training
21 | ```buildoutcfg
22 | python train.py --name "default" --config "configs/default.yaml"
23 | ```
24 | For TensorBoard:
25 | ``tensorboard --logdir logs/
26 | ``
27 | ## References
28 | - [DenseASPP for Semantic Segmentation in Street Scenes](https://github.com/DeepMotionAIResearch/DenseASPP)
29 | - [ResUNet++ with Conditional Random Field](https://github.com/DebeshJha/ResUNetplusplus_with-CRF-and-TTA)
30 | - [SENet](https://github.com/moskomule/senet.pytorch)
31 | - [Road Extraction Using PyTorch](https://github.com/jeffwen/road_building_extraction)
32 | - [ASPP Module](https://medium.com/@aidanaden/deeplabv3-pytorch-code-explained-line-by-line-sort-of-19e729bb2af6)
33 | - [Deep Residual-Unet](https://arxiv.org/pdf/1711.10684.pdf)
34 | - [Squeeze-and-Excitation Networks](https://arxiv.org/pdf/1709.01507.pdf)
35 | - [ResUNet++](https://arxiv.org/pdf/1911.07067.pdf)
36 | - [Unet](https://arxiv.org/pdf/1505.04597.pdf)
37 | - [Brain tumor segmentation](https://github.com/galprz/brain-tumor-segmentation)
--------------------------------------------------------------------------------
/check_files_statistics.py:
--------------------------------------------------------------------------------
1 | import hparams as hp  # expects a root-level hparams.py (gitignored here) exposing `train` and `valid` paths
2 | from PIL import Image
3 | import shutil
4 | import os
5 | import glob
6 | import tqdm
7 |
8 | if __name__ == "__main__":
9 |     train_dir = hp.train
10 |     valid_dir = hp.valid
11 |
12 |     train_mask_crop_dir = os.path.join(train_dir, "mask_crop")
13 |     mask_files = glob.glob(
14 |         os.path.join(train_mask_crop_dir, "**", "*.jpg"), recursive=True
15 |     )
16 |
17 |     noisy_mask_files = os.path.join(train_dir, "noisy")
18 |     os.makedirs(noisy_mask_files, exist_ok=True)
19 |
20 |     count = 0
21 |     print("Total images: ", len(mask_files))
22 |     for f in mask_files:
23 |         img = Image.open(f)
24 |         img.load()
25 |         extrema = img.convert("L").getextrema()
26 |         if extrema == (0, 0):
27 |             count = count + 1
28 |             shutil.copy2(f, f.replace("mask_crop", "noisy"))
29 |             ## If file exists, delete it ##
30 |             if os.path.isfile(f):
31 |                 os.remove(f)
32 |             else:  ## Show an error ##
33 |                 print("Error: %s file not found" % f)
34 |
35 |     print("Useless (all-black) images: ", count)
36 |
--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
1 | train: "/mnt/Karan/ResUnet/data/training"
2 | valid: "/mnt/Karan/ResUnet/data/testing"
3 | log: "logs"
4 | logging_step: 100
5 | validation_interval: 2000 # Save and valid have same interval
6 | checkpoints: 
"checkpoints" 7 | 8 | batch_size: 16 9 | lr: 0.001 10 | RESNET_PLUS_PLUS: True 11 | IMAGE_SIZE: 1500 12 | CROP_SIZE: 224 -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rishikksh20/ResUnet/99985126ada5d649f8a3e2fd828cc67f1e606920/core/__init__.py -------------------------------------------------------------------------------- /core/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class ResidualConv(nn.Module): 6 | def __init__(self, input_dim, output_dim, stride, padding): 7 | super(ResidualConv, self).__init__() 8 | 9 | self.conv_block = nn.Sequential( 10 | nn.BatchNorm2d(input_dim), 11 | nn.ReLU(), 12 | nn.Conv2d( 13 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding 14 | ), 15 | nn.BatchNorm2d(output_dim), 16 | nn.ReLU(), 17 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), 18 | ) 19 | self.conv_skip = nn.Sequential( 20 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), 21 | nn.BatchNorm2d(output_dim), 22 | ) 23 | 24 | def forward(self, x): 25 | 26 | return self.conv_block(x) + self.conv_skip(x) 27 | 28 | 29 | class Upsample(nn.Module): 30 | def __init__(self, input_dim, output_dim, kernel, stride): 31 | super(Upsample, self).__init__() 32 | 33 | self.upsample = nn.ConvTranspose2d( 34 | input_dim, output_dim, kernel_size=kernel, stride=stride 35 | ) 36 | 37 | def forward(self, x): 38 | return self.upsample(x) 39 | 40 | 41 | class Squeeze_Excite_Block(nn.Module): 42 | def __init__(self, channel, reduction=16): 43 | super(Squeeze_Excite_Block, self).__init__() 44 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 45 | self.fc = nn.Sequential( 46 | nn.Linear(channel, channel // reduction, bias=False), 47 | nn.ReLU(inplace=True), 48 | nn.Linear(channel // reduction, channel, bias=False), 49 | nn.Sigmoid(), 50 | ) 51 | 52 | def forward(self, x): 53 | b, c, _, _ = x.size() 54 | y = self.avg_pool(x).view(b, c) 55 | y = self.fc(y).view(b, c, 1, 1) 56 | return x * y.expand_as(x) 57 | 58 | 59 | class ASPP(nn.Module): 60 | def __init__(self, in_dims, out_dims, rate=[6, 12, 18]): 61 | super(ASPP, self).__init__() 62 | 63 | self.aspp_block1 = nn.Sequential( 64 | nn.Conv2d( 65 | in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0] 66 | ), 67 | nn.ReLU(inplace=True), 68 | nn.BatchNorm2d(out_dims), 69 | ) 70 | self.aspp_block2 = nn.Sequential( 71 | nn.Conv2d( 72 | in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1] 73 | ), 74 | nn.ReLU(inplace=True), 75 | nn.BatchNorm2d(out_dims), 76 | ) 77 | self.aspp_block3 = nn.Sequential( 78 | nn.Conv2d( 79 | in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2] 80 | ), 81 | nn.ReLU(inplace=True), 82 | nn.BatchNorm2d(out_dims), 83 | ) 84 | 85 | self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1) 86 | self._init_weights() 87 | 88 | def forward(self, x): 89 | x1 = self.aspp_block1(x) 90 | x2 = self.aspp_block2(x) 91 | x3 = self.aspp_block3(x) 92 | out = torch.cat([x1, x2, x3], dim=1) 93 | return self.output(out) 94 | 95 | def _init_weights(self): 96 | for m in self.modules(): 97 | if isinstance(m, nn.Conv2d): 98 | nn.init.kaiming_normal_(m.weight) 99 | elif isinstance(m, nn.BatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | 103 | 104 | class Upsample_(nn.Module): 105 | def __init__(self, 
scale=2): 106 | super(Upsample_, self).__init__() 107 | 108 | self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale) 109 | 110 | def forward(self, x): 111 | return self.upsample(x) 112 | 113 | 114 | class AttentionBlock(nn.Module): 115 | def __init__(self, input_encoder, input_decoder, output_dim): 116 | super(AttentionBlock, self).__init__() 117 | 118 | self.conv_encoder = nn.Sequential( 119 | nn.BatchNorm2d(input_encoder), 120 | nn.ReLU(), 121 | nn.Conv2d(input_encoder, output_dim, 3, padding=1), 122 | nn.MaxPool2d(2, 2), 123 | ) 124 | 125 | self.conv_decoder = nn.Sequential( 126 | nn.BatchNorm2d(input_decoder), 127 | nn.ReLU(), 128 | nn.Conv2d(input_decoder, output_dim, 3, padding=1), 129 | ) 130 | 131 | self.conv_attn = nn.Sequential( 132 | nn.BatchNorm2d(output_dim), 133 | nn.ReLU(), 134 | nn.Conv2d(output_dim, 1, 1), 135 | ) 136 | 137 | def forward(self, x1, x2): 138 | out = self.conv_encoder(x1) + self.conv_decoder(x2) 139 | out = self.conv_attn(out) 140 | return out * x2 141 | -------------------------------------------------------------------------------- /core/res_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from core.modules import ResidualConv, Upsample 4 | 5 | 6 | class ResUnet(nn.Module): 7 | def __init__(self, channel, filters=[64, 128, 256, 512]): 8 | super(ResUnet, self).__init__() 9 | 10 | self.input_layer = nn.Sequential( 11 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), 12 | nn.BatchNorm2d(filters[0]), 13 | nn.ReLU(), 14 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), 15 | ) 16 | self.input_skip = nn.Sequential( 17 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) 18 | ) 19 | 20 | self.residual_conv_1 = ResidualConv(filters[0], filters[1], 2, 1) 21 | self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1) 22 | 23 | self.bridge = ResidualConv(filters[2], filters[3], 2, 1) 24 | 25 | self.upsample_1 = Upsample(filters[3], filters[3], 2, 2) 26 | self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1) 27 | 28 | self.upsample_2 = Upsample(filters[2], filters[2], 2, 2) 29 | self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1) 30 | 31 | self.upsample_3 = Upsample(filters[1], filters[1], 2, 2) 32 | self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], 1, 1) 33 | 34 | self.output_layer = nn.Sequential( 35 | nn.Conv2d(filters[0], 1, 1, 1), 36 | nn.Sigmoid(), 37 | ) 38 | 39 | def forward(self, x): 40 | # Encode 41 | x1 = self.input_layer(x) + self.input_skip(x) 42 | x2 = self.residual_conv_1(x1) 43 | x3 = self.residual_conv_2(x2) 44 | # Bridge 45 | x4 = self.bridge(x3) 46 | # Decode 47 | x4 = self.upsample_1(x4) 48 | x5 = torch.cat([x4, x3], dim=1) 49 | 50 | x6 = self.up_residual_conv1(x5) 51 | 52 | x6 = self.upsample_2(x6) 53 | x7 = torch.cat([x6, x2], dim=1) 54 | 55 | x8 = self.up_residual_conv2(x7) 56 | 57 | x8 = self.upsample_3(x8) 58 | x9 = torch.cat([x8, x1], dim=1) 59 | 60 | x10 = self.up_residual_conv3(x9) 61 | 62 | output = self.output_layer(x10) 63 | 64 | return output 65 | -------------------------------------------------------------------------------- /core/res_unet_plus.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from core.modules import ( 4 | ResidualConv, 5 | ASPP, 6 | AttentionBlock, 7 | Upsample_, 8 | Squeeze_Excite_Block, 9 | ) 10 | 11 | 12 | class 
ResUnetPlusPlus(nn.Module): 13 | def __init__(self, channel, filters=[32, 64, 128, 256, 512]): 14 | super(ResUnetPlusPlus, self).__init__() 15 | 16 | self.input_layer = nn.Sequential( 17 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), 18 | nn.BatchNorm2d(filters[0]), 19 | nn.ReLU(), 20 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), 21 | ) 22 | self.input_skip = nn.Sequential( 23 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) 24 | ) 25 | 26 | self.squeeze_excite1 = Squeeze_Excite_Block(filters[0]) 27 | 28 | self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1) 29 | 30 | self.squeeze_excite2 = Squeeze_Excite_Block(filters[1]) 31 | 32 | self.residual_conv2 = ResidualConv(filters[1], filters[2], 2, 1) 33 | 34 | self.squeeze_excite3 = Squeeze_Excite_Block(filters[2]) 35 | 36 | self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1) 37 | 38 | self.aspp_bridge = ASPP(filters[3], filters[4]) 39 | 40 | self.attn1 = AttentionBlock(filters[2], filters[4], filters[4]) 41 | self.upsample1 = Upsample_(2) 42 | self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1) 43 | 44 | self.attn2 = AttentionBlock(filters[1], filters[3], filters[3]) 45 | self.upsample2 = Upsample_(2) 46 | self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1) 47 | 48 | self.attn3 = AttentionBlock(filters[0], filters[2], filters[2]) 49 | self.upsample3 = Upsample_(2) 50 | self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1) 51 | 52 | self.aspp_out = ASPP(filters[1], filters[0]) 53 | 54 | self.output_layer = nn.Sequential(nn.Conv2d(filters[0], 1, 1), nn.Sigmoid()) 55 | 56 | def forward(self, x): 57 | x1 = self.input_layer(x) + self.input_skip(x) 58 | 59 | x2 = self.squeeze_excite1(x1) 60 | x2 = self.residual_conv1(x2) 61 | 62 | x3 = self.squeeze_excite2(x2) 63 | x3 = self.residual_conv2(x3) 64 | 65 | x4 = self.squeeze_excite3(x3) 66 | x4 = self.residual_conv3(x4) 67 | 68 | x5 = self.aspp_bridge(x4) 69 | 70 | x6 = self.attn1(x3, x5) 71 | x6 = self.upsample1(x6) 72 | x6 = torch.cat([x6, x3], dim=1) 73 | x6 = self.up_residual_conv1(x6) 74 | 75 | x7 = self.attn2(x2, x6) 76 | x7 = self.upsample2(x7) 77 | x7 = torch.cat([x7, x2], dim=1) 78 | x7 = self.up_residual_conv2(x7) 79 | 80 | x8 = self.attn3(x1, x7) 81 | x8 = self.upsample3(x8) 82 | x8 = torch.cat([x8, x1], dim=1) 83 | x8 = self.up_residual_conv3(x8) 84 | 85 | x9 = self.aspp_out(x8) 86 | out = self.output_layer(x9) 87 | 88 | return out 89 | -------------------------------------------------------------------------------- /core/unet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | import torch 4 | 5 | # encoding block 6 | class encoding_block(nn.Module): 7 | """ 8 | Convolutional batch norm block with relu activation (main block used in the encoding steps) 9 | """ 10 | 11 | def __init__( 12 | self, 13 | in_size, 14 | out_size, 15 | kernel_size=3, 16 | padding=0, 17 | stride=1, 18 | dilation=1, 19 | batch_norm=True, 20 | dropout=False, 21 | ): 22 | super().__init__() 23 | 24 | if batch_norm: 25 | 26 | # reflection padding for same size output as input (reflection padding has shown better results than zero padding) 27 | layers = [ 28 | nn.ReflectionPad2d(padding=(kernel_size - 1) // 2), 29 | nn.Conv2d( 30 | in_size, 31 | out_size, 32 | kernel_size=kernel_size, 33 | padding=padding, 34 | stride=stride, 35 | dilation=dilation, 36 | ), 37 | nn.PReLU(), 38 | nn.BatchNorm2d(out_size), 39 
| nn.ReflectionPad2d(padding=(kernel_size - 1) // 2), 40 | nn.Conv2d( 41 | out_size, 42 | out_size, 43 | kernel_size=kernel_size, 44 | padding=padding, 45 | stride=stride, 46 | dilation=dilation, 47 | ), 48 | nn.PReLU(), 49 | nn.BatchNorm2d(out_size), 50 | ] 51 | 52 | else: 53 | layers = [ 54 | nn.ReflectionPad2d(padding=(kernel_size - 1) // 2), 55 | nn.Conv2d( 56 | in_size, 57 | out_size, 58 | kernel_size=kernel_size, 59 | padding=padding, 60 | stride=stride, 61 | dilation=dilation, 62 | ), 63 | nn.PReLU(), 64 | nn.ReflectionPad2d(padding=(kernel_size - 1) // 2), 65 | nn.Conv2d( 66 | out_size, 67 | out_size, 68 | kernel_size=kernel_size, 69 | padding=padding, 70 | stride=stride, 71 | dilation=dilation, 72 | ), 73 | nn.PReLU(), 74 | ] 75 | 76 | if dropout: 77 | layers.append(nn.Dropout()) 78 | 79 | self.encoding_block = nn.Sequential(*layers) 80 | 81 | def forward(self, input): 82 | 83 | output = self.encoding_block(input) 84 | 85 | return output 86 | 87 | 88 | # decoding block 89 | class decoding_block(nn.Module): 90 | def __init__(self, in_size, out_size, batch_norm=False, upsampling=True): 91 | super().__init__() 92 | 93 | if upsampling: 94 | self.up = nn.Sequential( 95 | nn.Upsample(mode="bilinear", scale_factor=2), 96 | nn.Conv2d(in_size, out_size, kernel_size=1), 97 | ) 98 | 99 | else: 100 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2) 101 | 102 | self.conv = encoding_block(in_size, out_size, batch_norm=batch_norm) 103 | 104 | def forward(self, input1, input2): 105 | 106 | output2 = self.up(input2) 107 | 108 | output1 = nn.functional.upsample(input1, output2.size()[2:], mode="bilinear") 109 | 110 | return self.conv(torch.cat([output1, output2], 1)) 111 | 112 | 113 | class UNet(nn.Module): 114 | """ 115 | Main UNet architecture 116 | """ 117 | 118 | def __init__(self, num_classes=1): 119 | super().__init__() 120 | 121 | # encoding 122 | self.conv1 = encoding_block(3, 64) 123 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 124 | 125 | self.conv2 = encoding_block(64, 128) 126 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 127 | 128 | self.conv3 = encoding_block(128, 256) 129 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 130 | 131 | self.conv4 = encoding_block(256, 512) 132 | self.maxpool4 = nn.MaxPool2d(kernel_size=2) 133 | 134 | # center 135 | self.center = encoding_block(512, 1024) 136 | 137 | # decoding 138 | self.decode4 = decoding_block(1024, 512) 139 | self.decode3 = decoding_block(512, 256) 140 | self.decode2 = decoding_block(256, 128) 141 | self.decode1 = decoding_block(128, 64) 142 | 143 | # final 144 | self.final = nn.Conv2d(64, num_classes, kernel_size=1) 145 | 146 | def forward(self, input): 147 | 148 | # encoding 149 | conv1 = self.conv1(input) 150 | maxpool1 = self.maxpool1(conv1) 151 | 152 | conv2 = self.conv2(maxpool1) 153 | maxpool2 = self.maxpool2(conv2) 154 | 155 | conv3 = self.conv3(maxpool2) 156 | maxpool3 = self.maxpool3(conv3) 157 | 158 | conv4 = self.conv4(maxpool3) 159 | maxpool4 = self.maxpool4(conv4) 160 | 161 | # center 162 | center = self.center(maxpool4) 163 | 164 | # decoding 165 | decode4 = self.decode4(conv4, center) 166 | 167 | decode3 = self.decode3(conv3, decode4) 168 | 169 | decode2 = self.decode2(conv2, decode3) 170 | 171 | decode1 = self.decode1(conv1, decode2) 172 | 173 | # final 174 | final = nn.functional.upsample( 175 | self.final(decode1), input.size()[2:], mode="bilinear" 176 | ) 177 | 178 | return final 179 | 180 | 181 | class UNetSmall(nn.Module): 182 | """ 183 | Main UNet architecture 184 | """ 185 | 186 | def 
__init__(self, num_classes=1): 187 | super().__init__() 188 | 189 | # encoding 190 | self.conv1 = encoding_block(3, 32) 191 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 192 | 193 | self.conv2 = encoding_block(32, 64) 194 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 195 | 196 | self.conv3 = encoding_block(64, 128) 197 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 198 | 199 | self.conv4 = encoding_block(128, 256) 200 | self.maxpool4 = nn.MaxPool2d(kernel_size=2) 201 | 202 | # center 203 | self.center = encoding_block(256, 512) 204 | 205 | # decoding 206 | self.decode4 = decoding_block(512, 256) 207 | self.decode3 = decoding_block(256, 128) 208 | self.decode2 = decoding_block(128, 64) 209 | self.decode1 = decoding_block(64, 32) 210 | 211 | # final 212 | self.final = nn.Conv2d(32, num_classes, kernel_size=1) 213 | 214 | def forward(self, input): 215 | 216 | # encoding 217 | conv1 = self.conv1(input) 218 | maxpool1 = self.maxpool1(conv1) 219 | 220 | conv2 = self.conv2(maxpool1) 221 | maxpool2 = self.maxpool2(conv2) 222 | 223 | conv3 = self.conv3(maxpool2) 224 | maxpool3 = self.maxpool3(conv3) 225 | 226 | conv4 = self.conv4(maxpool3) 227 | maxpool4 = self.maxpool4(conv4) 228 | 229 | # center 230 | center = self.center(maxpool4) 231 | 232 | # decoding 233 | decode4 = self.decode4(conv4, center) 234 | 235 | decode3 = self.decode3(conv3, decode4) 236 | 237 | decode2 = self.decode2(conv2, decode3) 238 | 239 | decode1 = self.decode1(conv1, decode2) 240 | 241 | # final 242 | final = nn.functional.upsample( 243 | self.final(decode1), input.size()[2:], mode="bilinear" 244 | ) 245 | 246 | return final 247 | -------------------------------------------------------------------------------- /dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | """" Modified version of https://github.com/jeffwen/road_building_extraction/blob/master/src/utils/data_utils.py """ 2 | from __future__ import print_function, division 3 | from torch.utils.data import Dataset 4 | from skimage import io 5 | import glob 6 | import os 7 | import torch 8 | from torchvision import transforms 9 | 10 | 11 | class ImageDataset(Dataset): 12 | """Massachusetts Road and Building dataset""" 13 | 14 | def __init__(self, hp, train=True, transform=None): 15 | """ 16 | Args: 17 | csv_file (string): Path to the csv file with image paths 18 | train_valid_test (string): 'train', 'valid', or 'test' 19 | root_dir (string): 'mass_roads', 'mass_roads_crop', or 'mass_buildings' 20 | transform (callable, optional): Optional transform to be applied on a sample. 
21 | """ 22 | self.train = train 23 | self.path = hp.train if train else hp.valid 24 | self.mask_list = glob.glob( 25 | os.path.join(self.path, "mask_crop", "*.jpg"), recursive=True 26 | ) 27 | self.transform = transform 28 | 29 | def __len__(self): 30 | return len(self.mask_list) 31 | 32 | def __getitem__(self, idx): 33 | maskpath = self.mask_list[idx] 34 | image = io.imread(maskpath.replace("mask_crop", "input_crop")) 35 | mask = io.imread(maskpath) 36 | 37 | sample = {"sat_img": image, "map_img": mask} 38 | 39 | if self.transform: 40 | sample = self.transform(sample) 41 | 42 | return sample 43 | 44 | 45 | class ToTensorTarget(object): 46 | """Convert ndarrays in sample to Tensors.""" 47 | 48 | def __call__(self, sample): 49 | sat_img, map_img = sample["sat_img"], sample["map_img"] 50 | 51 | # swap color axis because 52 | # numpy image: H x W x C 53 | # torch image: C X H X W 54 | 55 | return { 56 | "sat_img": transforms.functional.to_tensor(sat_img), 57 | "map_img": torch.from_numpy(map_img).unsqueeze(0).float().div(255), 58 | } # unsqueeze for the channel dimension 59 | 60 | 61 | class NormalizeTarget(transforms.Normalize): 62 | """Normalize a tensor and also return the target""" 63 | 64 | def __call__(self, sample): 65 | return { 66 | "sat_img": transforms.functional.normalize( 67 | sample["sat_img"], self.mean, self.std 68 | ), 69 | "map_img": sample["map_img"], 70 | } 71 | 72 | 73 | # https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/3 74 | class UnNormalize(object): 75 | def __init__(self, mean, std): 76 | self.mean = mean 77 | self.std = std 78 | 79 | def __call__(self, tensor): 80 | """ 81 | Args: 82 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 83 | Returns: 84 | Tensor: Normalized image. 85 | """ 86 | for t, m, s in zip(tensor, self.mean, self.std): 87 | t.mul_(s).add_(m) 88 | # The normalize code -> t.sub_(m).div_(s) 89 | return tensor 90 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import os 4 | import glob 5 | import tqdm 6 | import argparse 7 | from utils.hparams import HParam 8 | 9 | def load_image( infilename) : 10 | img = Image.open( infilename ) 11 | img.load() 12 | if img.mode == 'P': 13 | img.convert('RGB') 14 | data = np.asarray( img, dtype="int32" ) 15 | return data 16 | 17 | def start_points(size, split_size, overlap=0): 18 | points = [0] 19 | stride = int(split_size * (1-overlap)) 20 | counter = 1 21 | while True: 22 | pt = stride * counter 23 | if pt + split_size >= size: 24 | points.append(size - split_size) 25 | break 26 | else: 27 | points.append(pt) 28 | counter += 1 29 | return points 30 | 31 | def crop_image_mask(image_dir, mask_dir, mask_path, X_points, Y_points, split_height=224, split_width=224): 32 | img_id = os.path.basename(mask_path).split(".")[0] 33 | mask = load_image(mask_path) 34 | img = load_image(mask_path.replace("output", "input")) 35 | 36 | count = 0 37 | num_skipped = 1 38 | for i in Y_points: 39 | for j in X_points: 40 | new_image = img[i:i + split_height, j:j + split_width] 41 | new_mask = mask[i:i + split_height, j:j + split_width] 42 | new_mask[new_mask > 1] = 255 43 | # Skip any Image that is more than 99% empty. 
44 | if np.any(new_mask): 45 | num_black_pixels, num_white_pixels = np.unique(new_mask, return_counts=True)[1] 46 | 47 | if num_white_pixels / num_black_pixels < 0.01: 48 | num_skipped += 1 49 | continue 50 | 51 | mask_ = Image.fromarray(new_mask.astype(np.uint8)) 52 | mask_.save("{}/{}_{}.jpg".format(mask_dir, img_id, count), "JPEG") 53 | im = Image.fromarray(new_image.astype(np.uint8)) 54 | im.save("{}/{}_{}.jpg".format(image_dir, img_id, count), "JPEG") 55 | count = count + 1 56 | 57 | if __name__ == '__main__': 58 | 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument('-c', '--config', type=str, required=True, 61 | help="yaml file for configuration") 62 | parser.add_argument('-t', '--train', type=str, required=True, 63 | help="Training Folder.") 64 | parser.add_argument('-v', '--valid', type=str, required=True, 65 | help="Validation Folder") 66 | args = parser.parse_args() 67 | 68 | hp = HParam(args.config) 69 | with open(args.config, 'r') as f: 70 | hp_str = ''.join(f.readlines()) 71 | 72 | train_dir = args.train 73 | valid_dir = args.valid 74 | X_points = start_points(hp.IMAGE_SIZE, hp.CROP_SIZE, 0.14) 75 | Y_points = start_points(hp.IMAGE_SIZE, hp.CROP_SIZE, 0.14) 76 | 77 | ## Training data 78 | train_img_dir = os.path.join(train_dir, "input") 79 | train_mask_dir = os.path.join(train_dir, "output") 80 | train_img_crop_dir = os.path.join(hp.train, "input_crop") 81 | os.makedirs(train_img_crop_dir, exist_ok=True) 82 | train_mask_crop_dir = os.path.join(hp.train, "mask_crop") 83 | os.makedirs(train_mask_crop_dir, exist_ok=True) 84 | 85 | img_files = glob.glob(os.path.join(train_img_dir, '**', '*.png'), recursive=True) 86 | mask_files = glob.glob(os.path.join(train_mask_dir, '**', '*.png'), recursive=True) 87 | print("Length of image :", len(img_files)) 88 | print("Length of mask :", len(mask_files)) 89 | #assert len(img_files) == len(mask_files) 90 | 91 | 92 | 93 | for mask_path in tqdm.tqdm(mask_files, desc='Cropping Training images'): 94 | crop_image_mask(train_img_crop_dir, train_mask_crop_dir, mask_path, X_points, Y_points) 95 | 96 | ### Validation data 97 | valid_img_dir = os.path.join(valid_dir, "input") 98 | valid_mask_dir = os.path.join(valid_dir, "output") 99 | valid_img_crop_dir = os.path.join(hp.valid, "input_crop") 100 | os.makedirs(valid_img_crop_dir, exist_ok=True) 101 | valid_mask_crop_dir = os.path.join(hp.valid, "mask_crop") 102 | os.makedirs(valid_mask_crop_dir, exist_ok=True) 103 | 104 | img_files = glob.glob(os.path.join(valid_img_dir, '**', '*.png'), recursive=True) 105 | mask_files = glob.glob(os.path.join(valid_mask_dir, '**', '*.png'), recursive=True) 106 | assert len(img_files) == len(mask_files) 107 | 108 | 109 | for mask_path in tqdm.tqdm(mask_files, desc='Cropping Validation images'): 110 | crop_image_mask(valid_img_crop_dir, valid_mask_crop_dir, mask_path, X_points, Y_points) 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | torchvision 3 | matplotlib==2.1.0 4 | tensorboardX 5 | numpy==1.16.3 6 | librosa==0.7.2 7 | inflect==0.2.5 8 | scipy==1.0.0 9 | Unidecode==1.0.22 10 | pillow 11 | tqdm 12 | configargparse 13 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rishikksh20/ResUnet/99985126ada5d649f8a3e2fd828cc67f1e606920/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_res_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from core.res_unet import ResUnet, ResidualConv, Upsample 3 | 4 | def test_resunet(): 5 | img = torch.ones(1, 3, 224, 224) 6 | resunet = ResUnet(3) 7 | assert resunet(img).shape == torch.Size([1, 1, 224, 224]) 8 | 9 | 10 | def test_residual_conv(): 11 | x = torch.ones(1, 64, 224, 224) 12 | res_conv = ResidualConv(64, 128, 2, 1) 13 | assert res_conv(x).shape == torch.Size([1, 128, 112, 112]) 14 | 15 | 16 | def test_upsample(): 17 | x = torch.ones(1, 512, 28, 28) 18 | upsample = Upsample(512, 512, 2, 2) 19 | assert upsample(x).shape == torch.Size([1, 512, 56, 56]) 20 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.simplefilter("ignore", (UserWarning, FutureWarning)) 4 | from utils.hparams import HParam 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | from tqdm import tqdm 8 | from dataset import dataloader 9 | from utils import metrics 10 | from core.res_unet import ResUnet 11 | from core.res_unet_plus import ResUnetPlusPlus 12 | from utils.logger import MyWriter 13 | import torch 14 | import argparse 15 | import os 16 | 17 | 18 | def main(hp, num_epochs, resume, name): 19 | 20 | checkpoint_dir = "{}/{}".format(hp.checkpoints, name) 21 | os.makedirs(checkpoint_dir, exist_ok=True) 22 | 23 | os.makedirs("{}/{}".format(hp.log, name), exist_ok=True) 24 | writer = MyWriter("{}/{}".format(hp.log, name)) 25 | # get model 26 | 27 | if hp.RESNET_PLUS_PLUS: 28 | model = ResUnetPlusPlus(3).cuda() 29 | else: 30 | model = ResUnet(3, 64).cuda() 31 | 32 | # set up binary cross entropy and dice loss 33 | criterion = metrics.BCEDiceLoss() 34 | 35 | # optimizer 36 | # optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, nesterov=True) 37 | # optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5) 38 | optimizer = torch.optim.Adam(model.parameters(), lr=hp.lr) 39 | 40 | # decay LR 41 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1) 42 | 43 | # starting params 44 | best_loss = 999 45 | start_epoch = 0 46 | # optionally resume from a checkpoint 47 | if resume: 48 | if os.path.isfile(resume): 49 | print("=> loading checkpoint '{}'".format(resume)) 50 | checkpoint = torch.load(resume) 51 | 52 | start_epoch = checkpoint["epoch"] 53 | 54 | best_loss = checkpoint["best_loss"] 55 | model.load_state_dict(checkpoint["state_dict"]) 56 | optimizer.load_state_dict(checkpoint["optimizer"]) 57 | print( 58 | "=> loaded checkpoint '{}' (epoch {})".format( 59 | resume, checkpoint["epoch"] 60 | ) 61 | ) 62 | else: 63 | print("=> no checkpoint found at '{}'".format(args.resume)) 64 | 65 | # get data 66 | mass_dataset_train = dataloader.ImageDataset( 67 | hp, transform=transforms.Compose([dataloader.ToTensorTarget()]) 68 | ) 69 | 70 | mass_dataset_val = dataloader.ImageDataset( 71 | hp, False, transform=transforms.Compose([dataloader.ToTensorTarget()]) 72 | ) 73 | 74 | # creating loaders 75 | train_dataloader = DataLoader( 76 | mass_dataset_train, batch_size=hp.batch_size, num_workers=2, shuffle=True 77 | ) 78 | val_dataloader = 
DataLoader( 79 | mass_dataset_val, batch_size=1, num_workers=2, shuffle=False 80 | ) 81 | 82 | step = 0 83 | for epoch in range(start_epoch, num_epochs): 84 | print("Epoch {}/{}".format(epoch, num_epochs - 1)) 85 | print("-" * 10) 86 | 87 | # step the learning rate scheduler 88 | lr_scheduler.step() 89 | 90 | # run training and validation 91 | # logging accuracy and loss 92 | train_acc = metrics.MetricTracker() 93 | train_loss = metrics.MetricTracker() 94 | # iterate over data 95 | 96 | loader = tqdm(train_dataloader, desc="training") 97 | for idx, data in enumerate(loader): 98 | 99 | # get the inputs and wrap in Variable 100 | inputs = data["sat_img"].cuda() 101 | labels = data["map_img"].cuda() 102 | 103 | # zero the parameter gradients 104 | optimizer.zero_grad() 105 | 106 | # forward 107 | # prob_map = model(inputs) # last activation was a sigmoid 108 | # outputs = (prob_map > 0.3).float() 109 | outputs = model(inputs) 110 | # outputs = torch.nn.functional.sigmoid(outputs) 111 | 112 | loss = criterion(outputs, labels) 113 | 114 | # backward 115 | loss.backward() 116 | optimizer.step() 117 | 118 | train_acc.update(metrics.dice_coeff(outputs, labels), outputs.size(0)) 119 | train_loss.update(loss.data.item(), outputs.size(0)) 120 | 121 | # tensorboard logging 122 | if step % hp.logging_step == 0: 123 | writer.log_training(train_loss.avg, train_acc.avg, step) 124 | loader.set_description( 125 | "Training Loss: {:.4f} Acc: {:.4f}".format( 126 | train_loss.avg, train_acc.avg 127 | ) 128 | ) 129 | 130 | # Validatiuon 131 | if step % hp.validation_interval == 0: 132 | valid_metrics = validation( 133 | val_dataloader, model, criterion, writer, step 134 | ) 135 | save_path = os.path.join( 136 | checkpoint_dir, "%s_checkpoint_%04d.pt" % (name, step) 137 | ) 138 | # store best loss and save a model checkpoint 139 | best_loss = min(valid_metrics["valid_loss"], best_loss) 140 | torch.save( 141 | { 142 | "step": step, 143 | "epoch": epoch, 144 | "arch": "ResUnet", 145 | "state_dict": model.state_dict(), 146 | "best_loss": best_loss, 147 | "optimizer": optimizer.state_dict(), 148 | }, 149 | save_path, 150 | ) 151 | print("Saved checkpoint to: %s" % save_path) 152 | 153 | step += 1 154 | 155 | 156 | def validation(valid_loader, model, criterion, logger, step): 157 | 158 | # logging accuracy and loss 159 | valid_acc = metrics.MetricTracker() 160 | valid_loss = metrics.MetricTracker() 161 | 162 | # switch to evaluate mode 163 | model.eval() 164 | 165 | # Iterate over data. 
166 | for idx, data in enumerate(tqdm(valid_loader, desc="validation")): 167 | 168 | # get the inputs and wrap in Variable 169 | inputs = data["sat_img"].cuda() 170 | labels = data["map_img"].cuda() 171 | 172 | # forward 173 | # prob_map = model(inputs) # last activation was a sigmoid 174 | # outputs = (prob_map > 0.3).float() 175 | outputs = model(inputs) 176 | # outputs = torch.nn.functional.sigmoid(outputs) 177 | 178 | loss = criterion(outputs, labels) 179 | 180 | valid_acc.update(metrics.dice_coeff(outputs, labels), outputs.size(0)) 181 | valid_loss.update(loss.data.item(), outputs.size(0)) 182 | if idx == 0: 183 | logger.log_images(inputs.cpu(), labels.cpu(), outputs.cpu(), step) 184 | logger.log_validation(valid_loss.avg, valid_acc.avg, step) 185 | 186 | print("Validation Loss: {:.4f} Acc: {:.4f}".format(valid_loss.avg, valid_acc.avg)) 187 | model.train() 188 | return {"valid_loss": valid_loss.avg, "valid_acc": valid_acc.avg} 189 | 190 | 191 | if __name__ == "__main__": 192 | parser = argparse.ArgumentParser(description="Road and Building Extraction") 193 | parser.add_argument( 194 | "-c", "--config", type=str, required=True, help="yaml file for configuration" 195 | ) 196 | parser.add_argument( 197 | "--epochs", 198 | default=75, 199 | type=int, 200 | metavar="N", 201 | help="number of total epochs to run", 202 | ) 203 | parser.add_argument( 204 | "--resume", 205 | default="", 206 | type=str, 207 | metavar="PATH", 208 | help="path to latest checkpoint (default: none)", 209 | ) 210 | parser.add_argument("--name", default="default", type=str, help="Experiment name") 211 | 212 | args = parser.parse_args() 213 | 214 | hp = HParam(args.config) 215 | with open(args.config, "r") as f: 216 | hp_str = "".join(f.readlines()) 217 | 218 | main(hp, num_epochs=args.epochs, resume=args.resume, name=args.name) 219 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rishikksh20/ResUnet/99985126ada5d649f8a3e2fd828cc67f1e606920/utils/__init__.py -------------------------------------------------------------------------------- /utils/augmentation.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.simplefilter("ignore", UserWarning) 4 | 5 | from skimage import transform 6 | from torchvision import transforms 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class RescaleTarget(object): 13 | """Rescale the image in a sample to a given size. 14 | 15 | Args: 16 | output_size (tuple or int): Desired output size. If tuple, output is 17 | matched to output_size. If int, smaller of image edges is matched 18 | to output_size keeping aspect ratio the same. 
19 | """ 20 | 21 | def __init__(self, output_size): 22 | assert isinstance(output_size, (int, tuple)) 23 | if isinstance(output_size, tuple): 24 | self.output_size = int(np.random.uniform(output_size[0], output_size[1])) 25 | else: 26 | self.output_size = output_size 27 | 28 | def __call__(self, sample): 29 | sat_img, map_img = sample["sat_img"], sample["map_img"] 30 | 31 | h, w = sat_img.shape[:2] 32 | 33 | if h > w: 34 | new_h, new_w = self.output_size * h / w, self.output_size 35 | else: 36 | new_h, new_w = self.output_size, self.output_size * w / h 37 | 38 | new_h, new_w = int(new_h), int(new_w) 39 | 40 | # change the range to 0-1 rather than 0-255, makes it easier to use sigmoid later 41 | sat_img = transform.resize(sat_img, (new_h, new_w)) 42 | 43 | map_img = transform.resize(map_img, (new_h, new_w)) 44 | 45 | return {"sat_img": sat_img, "map_img": map_img} 46 | 47 | 48 | class RandomRotationTarget(object): 49 | """Rotate the image and target randomly in a sample. 50 | 51 | Args: 52 | degrees (tuple or int): Range of degrees to select from. 53 | If degrees is a number instead of sequence like (min, max), the range of degrees 54 | will be (-degrees, +degrees). 55 | resize (boolean): Expand the image to fit 56 | """ 57 | 58 | def __init__(self, degrees, resize=False): 59 | if isinstance(degrees, int): 60 | if degrees < 0: 61 | raise ValueError("If degrees is a single number, it must be positive.") 62 | self.degrees = (-degrees, degrees) 63 | else: 64 | if isinstance(degrees, tuple): 65 | raise ValueError("Degrees needs to be either an int or tuple") 66 | self.degrees = degrees 67 | 68 | assert isinstance(resize, bool) 69 | 70 | self.resize = resize 71 | self.angle = np.random.uniform(self.degrees[0], self.degrees[1]) 72 | 73 | def __call__(self, sample): 74 | 75 | sat_img = transform.rotate(sample["sat_img"], self.angle, self.resize) 76 | map_img = transform.rotate(sample["map_img"], self.angle, self.resize) 77 | 78 | return {"sat_img": sat_img, "map_img": map_img} 79 | 80 | 81 | class RandomCropTarget(object): 82 | """ 83 | Crop the image and target randomly in a sample. 84 | 85 | Args: 86 | output_size (tuple or int): Desired output size. If int, square crop 87 | is made. 
88 | 89 | """ 90 | 91 | def __init__(self, output_size): 92 | assert isinstance(output_size, (int, tuple)) 93 | if isinstance(output_size, int): 94 | self.output_size = (output_size, output_size) 95 | else: 96 | assert len(output_size) == 2 97 | self.output_size = output_size 98 | 99 | def __call__(self, sample): 100 | 101 | sat_img, map_img = sample["sat_img"], sample["map_img"] 102 | 103 | h, w = sat_img.shape[:2] 104 | new_h, new_w = self.output_size 105 | 106 | top = np.random.randint(0, h - new_h) 107 | left = np.random.randint(0, w - new_w) 108 | 109 | sat_img = sat_img[top : top + new_h, left : left + new_w] 110 | map_img = map_img[top : top + new_h, left : left + new_w] 111 | 112 | return {"sat_img": sat_img, "map_img": map_img} 113 | -------------------------------------------------------------------------------- /utils/hparams.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/HarryVolek/PyTorch_Speaker_Verification 2 | 3 | import os 4 | import yaml 5 | 6 | 7 | def load_hparam_str(hp_str): 8 | path = "temp-restore.yaml" 9 | with open(path, "w") as f: 10 | f.write(hp_str) 11 | ret = HParam(path) 12 | os.remove(path) 13 | return ret 14 | 15 | 16 | def load_hparam(filename): 17 | stream = open(filename, "r") 18 | docs = yaml.load_all(stream, Loader=yaml.Loader) 19 | hparam_dict = dict() 20 | for doc in docs: 21 | for k, v in doc.items(): 22 | hparam_dict[k] = v 23 | return hparam_dict 24 | 25 | 26 | def merge_dict(user, default): 27 | if isinstance(user, dict) and isinstance(default, dict): 28 | for k, v in default.items(): 29 | if k not in user: 30 | user[k] = v 31 | else: 32 | user[k] = merge_dict(user[k], v) 33 | return user 34 | 35 | 36 | class Dotdict(dict): 37 | """ 38 | a dictionary that supports dot notation 39 | as well as dictionary access notation 40 | usage: d = DotDict() or d = DotDict({'val1':'first'}) 41 | set attributes: d.val2 = 'second' or d['val2'] = 'second' 42 | get attributes: d.val2 or d['val2'] 43 | """ 44 | 45 | __getattr__ = dict.__getitem__ 46 | __setattr__ = dict.__setitem__ 47 | __delattr__ = dict.__delitem__ 48 | 49 | def __init__(self, dct=None): 50 | dct = dict() if not dct else dct 51 | for key, value in dct.items(): 52 | if hasattr(value, "keys"): 53 | value = Dotdict(value) 54 | self[key] = value 55 | 56 | 57 | class HParam(Dotdict): 58 | def __init__(self, file): 59 | super(Dotdict, self).__init__() 60 | hp_dict = load_hparam(file) 61 | hp_dotdict = Dotdict(hp_dict) 62 | for k, v in hp_dotdict.items(): 63 | setattr(self, k, v) 64 | 65 | __getattr__ = Dotdict.__getitem__ 66 | __setattr__ = Dotdict.__setitem__ 67 | __delattr__ = Dotdict.__delitem__ 68 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/04-utils/tensorboard 2 | from tensorboardX import SummaryWriter 3 | from PIL import Image 4 | import numpy as np 5 | 6 | 7 | class MyWriter(SummaryWriter): 8 | def __init__(self, logdir): 9 | super(MyWriter, self).__init__(logdir) 10 | 11 | def log_training(self, dice_loss, iou, step): 12 | self.add_scalar("training/dice_loss", dice_loss, step) 13 | self.add_scalar("training/iou", iou, step) 14 | 15 | def log_validation(self, dice_loss, iou, step): 16 | self.add_scalar("validation/dice_loss", dice_loss, step) 17 | self.add_scalar("validation/iou", iou, step) 18 | 19 | 
def log_images(self, map, target, prediction, step): 20 | if len(map.shape) > 3: 21 | map = map.squeeze(0) 22 | if len(target.shape) > 2: 23 | target = target.squeeze() 24 | if len(prediction.shape) > 2: 25 | prediction = prediction.squeeze() 26 | self.add_image("map", map, step) 27 | self.add_image("mask", target.unsqueeze(0), step) 28 | self.add_image("prediction", prediction.unsqueeze(0), step) 29 | 30 | 31 | class LogWriter(SummaryWriter): 32 | def __init__(self, logdir): 33 | super(LogWriter, self).__init__(logdir) 34 | 35 | def log_scaler(self, key, value, step, prefix="Training", helper_func=None): 36 | if helper_func: 37 | value = helper_func(value) 38 | self.add_scalar("{}/{}".format(prefix, key), value, step) 39 | 40 | def log_image(self, key, value, step, prefix="Training", helper_func=None): 41 | if helper_func: 42 | value = helper_func(value) 43 | self.add_image("{}/{}".format(prefix, key), value, step) 44 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class BCEDiceLoss(nn.Module): 5 | def __init__(self, weight=None, size_average=True): 6 | super().__init__() 7 | 8 | def forward(self, input, target): 9 | pred = input.view(-1) 10 | truth = target.view(-1) 11 | 12 | # BCE loss 13 | bce_loss = nn.BCELoss()(pred, truth).double() 14 | 15 | # Dice Loss 16 | dice_coef = (2.0 * (pred * truth).double().sum() + 1) / ( 17 | pred.double().sum() + truth.double().sum() + 1 18 | ) 19 | 20 | return bce_loss + (1 - dice_coef) 21 | 22 | 23 | # https://github.com/pytorch/examples/blob/master/imagenet/main.py 24 | class MetricTracker(object): 25 | """Computes and stores the average and current value""" 26 | 27 | def __init__(self): 28 | self.reset() 29 | 30 | def reset(self): 31 | self.val = 0 32 | self.avg = 0 33 | self.sum = 0 34 | self.count = 0 35 | 36 | def update(self, val, n=1): 37 | self.val = val 38 | self.sum += val * n 39 | self.count += n 40 | self.avg = self.sum / self.count 41 | 42 | 43 | # https://stackoverflow.com/questions/48260415/pytorch-how-to-compute-iou-jaccard-index-for-semantic-segmentation 44 | def jaccard_index(input, target): 45 | 46 | intersection = (input * target).long().sum().data.cpu()[0] 47 | union = ( 48 | input.long().sum().data.cpu()[0] 49 | + target.long().sum().data.cpu()[0] 50 | - intersection 51 | ) 52 | 53 | if union == 0: 54 | return float("nan") 55 | else: 56 | return float(intersection) / float(max(union, 1)) 57 | 58 | 59 | # https://github.com/pytorch/pytorch/issues/1249 60 | def dice_coeff(input, target): 61 | num_in_target = input.size(0) 62 | 63 | smooth = 1.0 64 | 65 | pred = input.view(num_in_target, -1) 66 | truth = target.view(num_in_target, -1) 67 | 68 | intersection = (pred * truth).sum(1) 69 | 70 | loss = (2.0 * intersection + smooth) / (pred.sum(1) + truth.sum(1) + smooth) 71 | 72 | return loss.mean().item() 73 | --------------------------------------------------------------------------------
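As a quick sanity check of `utils/metrics.py`, here is a minimal usage sketch (illustrative, not part of the repo). It assumes the repository root is on `PYTHONPATH`; shapes and value ranges follow `train.py`, where both models end in a `Sigmoid`, so predictions and masks are `(N, 1, H, W)` tensors with values in `[0, 1]`.

```python
# Minimal sketch (illustrative, not from the repo): exercising BCEDiceLoss and
# dice_coeff the same way train.py does. Run from the repository root.
import torch
from utils.metrics import BCEDiceLoss, dice_coeff

# Fake sigmoid outputs and binary ground-truth masks, shaped (N, 1, H, W).
probs = torch.full((2, 1, 224, 224), 0.9)
masks = torch.ones(2, 1, 224, 224)

criterion = BCEDiceLoss()
loss = criterion(probs, masks)   # BCE + (1 - soft Dice), a 0-dim tensor
dice = dice_coeff(probs, masks)  # soft Dice coefficient, a Python float

print("loss = %.4f, dice = %.4f" % (loss.item(), dice))
```

Both functions expect probabilities rather than logits, which is why `ResUnet` and `ResUnetPlusPlus` both finish with `nn.Sigmoid` in their output layers.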