├── LICENSE ├── README.md ├── hrnet ├── README.md └── src │ ├── DeepNetworks │ ├── HRNet.py │ ├── ShiftNet.py │ └── __init__.py │ ├── __init__.py │ ├── lanczos.py │ ├── predict.py │ ├── train.py │ └── utils.py ├── input ├── TestAreaCyprus.geojson ├── TestAreaLithuania.geojson ├── config-local-hrn-deconv.json └── config-local-hrn-pix-shu.json ├── notebooks ├── 00-parse-deimos-metadata.ipynb ├── 00a-add-per-tile-median.ipynb ├── 00b-calculate-cloudfree-deimos-stats.ipynb ├── 01-download-to-eopatches.ipynb ├── 02a-add-clm-deimos.ipynb ├── 02b-add-clm-stats-to-patches.ipynb ├── 03-sampling.ipynb ├── 04-sampled-to-npz.ipynb ├── 05a-train-test-split.ipynb ├── 05b-find-cloudy-neighbours.ipynb ├── 05c-calculate-s2-normalizations.ipynb ├── 05d-calculate-scores.ipynb ├── 06-train.ipynb ├── 07-predict.ipynb └── 07b-predict-eopatches.ipynb ├── requirements-dev.txt ├── requirements.txt ├── setup.py ├── sr ├── __init__.py ├── data_loader.py ├── metrics.py ├── niva_models.py └── utils.py └── tests ├── conftest.py ├── data └── data_loader │ └── input │ ├── data_3m_eopatch-0277_66_2x.npz │ ├── data_eopatch-0277_66_2.npz │ ├── data_eopatch-0288_0_3.npz │ ├── deimos_min_max_norm.npz │ ├── npz_info.pq │ ├── s2_min_max_norm.npz │ └── s2_norm_per_country.pq ├── test_data_loader.py └── test_metrics.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Sentinel Hub 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 
13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DIONE - Super Resolution using Sentinel-2 and Deimos imagery 2 | This repo contains code to train a multitemporal super-resolution model for Sentinel-2 imagery using Deimos. 3 | 4 | You can find more information about this project in the blog post [Multi-temporal Super-Resolution on Sentinel-2 Imagery](https://medium.com/sentinel-hub/multi-temporal-super-resolution-on-sentinel-2-imagery-6089c2b39ebc) 5 | 6 | ## Introduction 7 | This project is part of the DIONE project where one of the missions is using novel techniques to improve the capabilities of satellite technology while integrating various data sources, such as very high resolution imagery, to, for example, enable monitoring of smaller agricultural parcels through the use of super resolution models. 8 | 9 | This project has received funding from the European Union’s Horizon 2020 research and innovation programme under grant agreement No 870378. 10 | 11 | Please visit the [Dione website](https://dione-project.eu/) for further information. 12 | ## Requirements 13 | The super resolution pipeline uses SentinelHub service to download Sentinel-2 and Deimos imagery. Amazon AWS S3 bucket was used to store the data. 
14 | _Deimos imagery is not public; however, any other Very High Resolution imagery can be used by adjusting the general workflow._ 15 | 16 | ## Installation and usage 17 | 18 | To install the sr package, clone the repository locally and, from within the repository, run the following commands: 19 | ``` 20 | pip install -r requirements.txt 21 | python setup.py install --user 22 | ``` 23 | The procedure is executed in notebooks; the basic functionality of each notebook is described below: 24 | * `00-parse-deimos-metadata.ipynb`: **Deimos specific**. Parses metadata for each ingested Deimos tile and saves it to a dataframe. 25 | * `00a-add-per-tile-median.ipynb` **Deimos specific**. Calculates median for each Deimos tile. 26 | * `00b-calculate-cloudfree-deimos-stats.ipynb` **Deimos specific**. Calculates Deimos tile statistics on cloudless areas. 27 | * `01-download-to-eopatches.ipynb` Download Sentinel-2 and ingested Deimos imagery to EOPatches. 28 | * `02a-add-clm-deimos.ipynb` Add cloud mask information to Deimos EOPatches. 29 | * `02b-add-clm-stats-to-patches.ipynb` Add cloudless normalization statistics to EOPatches. 30 | * `03-sampling.ipynb` Sample smaller patchlets from EOPatches. 31 | * `04-sampled-to-npz.ipynb` Construct NPZ files from patchlets. 32 | * `05a-train-test-split.ipynb` Split NPZ files into train/test/validation sets. 33 | * `05b-find-cloudy-neighbours.ipynb` Shadow detection by filtering neighbours of cloudy EOPatches. 34 | * `05c-calculate-s2-normalizations.ipynb` Calculate per-country Sentinel-2 normalization statistics. 35 | * `06-train.ipynb` Model training. 36 | * `07-predict.ipynb` Run model prediction on smaller patchlets. 37 | * `07b-predict-eopatches.ipynb` Run model prediction on whole EOPatches. 38 | 39 | ## Acknowledgments 40 | The code is adapted from [ElementAI's HighResNet code](https://github.com/ElementAI/HighRes-net). Refer also to the [published paper](https://arxiv.org/abs/2002.06460). 
class ResidualBlock(nn.Module):
    """Two same-padded ``Conv2d`` + ``PReLU`` layers wrapped in an identity skip connection."""

    def __init__(self, channel_size=64, kernel_size=3):
        """
        Args:
            channel_size : int, number of hidden channels (input and output)
            kernel_size : int, side of the square 2D kernel
        """

        super().__init__()
        padding = kernel_size // 2  # keeps the spatial size for odd kernel sizes
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=channel_size, out_channels=channel_size,
                      kernel_size=kernel_size, padding=padding),
            nn.PReLU(),
            nn.Conv2d(in_channels=channel_size, out_channels=channel_size,
                      kernel_size=kernel_size, padding=padding),
            nn.PReLU(),
        )

    def forward(self, x):
        """
        Args:
            x : tensor (B, C, W, H), hidden state
        Returns:
            tensor (B, C, W, H), ``x`` plus the output of the convolutional block
        """

        return x + self.block(x)


class Encoder(nn.Module):
    """Initial convolution, a stack of residual blocks and a final convolution."""

    def __init__(self, config):
        """
        Args:
            config : dict with keys ``in_channels``, ``num_layers``,
                ``kernel_size`` and ``channel_size``
        """

        super().__init__()

        in_channels = config["in_channels"]
        num_layers = config["num_layers"]
        kernel_size = config["kernel_size"]
        channel_size = config["channel_size"]
        padding = kernel_size // 2

        self.init_layer = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=channel_size,
                      kernel_size=kernel_size, padding=padding),
            nn.PReLU())

        self.res_layers = nn.Sequential(
            *[ResidualBlock(channel_size, kernel_size) for _ in range(num_layers)])

        self.final = nn.Sequential(
            nn.Conv2d(in_channels=channel_size, out_channels=channel_size,
                      kernel_size=kernel_size, padding=padding)
        )

    def forward(self, x):
        """
        Encodes an input tensor x.
        Args:
            x : tensor (B, C_in, W, H), input images
        Returns:
            out: tensor (B, C, W, H), hidden states
        """

        x = self.init_layer(x)
        x = self.res_layers(x)
        return self.final(x)


class RecuversiveNet(nn.Module):
    """Fuses a sequence of hidden states by repeatedly merging its two halves."""

    def __init__(self, config):
        """
        Args:
            config : dict with keys ``in_channels``, ``alpha_residual`` and ``kernel_size``
        """

        super().__init__()

        self.input_channels = config["in_channels"]
        self.alpha_residual = config["alpha_residual"]
        kernel_size = config["kernel_size"]
        padding = kernel_size // 2

        # Consumes a channel-wise concatenation of two hidden states (2*C) and
        # maps it back to a single hidden state (C).
        self.fuse = nn.Sequential(
            ResidualBlock(2 * self.input_channels, kernel_size),
            nn.Conv2d(in_channels=2 * self.input_channels, out_channels=self.input_channels,
                      kernel_size=kernel_size, padding=padding),
            nn.PReLU())

    def forward(self, x, alphas):
        """
        Fuses hidden states recursively.

        At every level the first half of the views is paired with the reversed
        second half; when the number of views is odd, the last view is left out
        of that level. The loop ends when a single view remains.

        Args:
            x : tensor (B, L, C, W, H), hidden states
            alphas : tensor (B, L, 1, 1, 1), indicator (0 if padded low-res view, 1 otherwise)
        Returns:
            out: tensor (B, C, W, H), mean over the remaining view axis
        """

        batch_size, nviews, channels, width, height = x.shape
        parity = nviews % 2
        half_len = nviews // 2

        while half_len > 0:
            alice = x[:, :half_len]  # first half hidden states (B, L/2, C, W, H)
            bob = x[:, half_len:nviews - parity]  # second half hidden states (B, L/2, C, W, H)
            bob = torch.flip(bob, [1])

            # concat hidden states across channels (B, L/2, 2*C, W, H)
            alice_and_bob = torch.cat([alice, bob], 2)
            alice_and_bob = alice_and_bob.view(-1, 2 * channels, width, height)
            x = self.fuse(alice_and_bob)
            x = x.view(batch_size, half_len, channels, width, height)  # (B, L/2, C, W, H)

            if self.alpha_residual:
                # Skip-connect padded views: where bob's alpha is 0 the fused
                # state is dropped and alice passes through unchanged.
                alphas_alice = alphas[:, :half_len]
                alphas_bob = alphas[:, half_len:nviews - parity]
                alphas_bob = torch.flip(alphas_bob, [1])
                x = alice + alphas_bob * x
                alphas = alphas_alice

            nviews = half_len
            parity = nviews % 2
            half_len = nviews // 2

        return torch.mean(x, 1)


class DecoderShuffle(nn.Module):
    """Decoder made of a reflect-padded convolution followed by ``nn.PixelShuffle``."""

    def __init__(self, config):
        """
        Args:
            config : dict whose ``pixel_shuffle`` entry holds ``in_channels``,
                ``out_channels``, ``kernel_size``, ``stride`` and ``scale``
        """

        super().__init__()

        shuffle_cfg = config["pixel_shuffle"]
        self.conv = nn.Sequential(nn.Conv2d(in_channels=shuffle_cfg["in_channels"],
                                            out_channels=shuffle_cfg["out_channels"],
                                            kernel_size=shuffle_cfg["kernel_size"],
                                            stride=shuffle_cfg["stride"],
                                            padding=shuffle_cfg["kernel_size"] // 2,
                                            padding_mode='reflect'),
                                  nn.PReLU())

        self.shuffle = nn.PixelShuffle(shuffle_cfg["scale"])

    def forward(self, x):
        """
        Decodes a hidden state x.
        Args:
            x : tensor (B, C, W, H), hidden states
        Returns:
            out: tensor, the convolved state rearranged by ``PixelShuffle``
                (channels divided by ``scale**2``, spatial dims multiplied by ``scale``)
        """

        return self.shuffle(self.conv(x))


class Decoder(nn.Module):
    """Decoder made of a transposed convolution + PReLU and a final convolution."""

    def __init__(self, config):
        """
        Args:
            config : dict with a ``deconv`` entry (``in_channels``, ``out_channels``,
                ``kernel_size``, ``stride``) and a ``final`` entry (``in_channels``,
                ``out_channels``, ``kernel_size``)
        """

        super().__init__()

        deconv_cfg = config["deconv"]
        final_cfg = config["final"]

        self.deconv = nn.Sequential(nn.ConvTranspose2d(in_channels=deconv_cfg["in_channels"],
                                                       out_channels=deconv_cfg["out_channels"],
                                                       kernel_size=deconv_cfg["kernel_size"],
                                                       stride=deconv_cfg["stride"], output_padding=1),
                                    nn.PReLU())

        self.final = nn.Conv2d(in_channels=final_cfg["in_channels"],
                               out_channels=final_cfg["out_channels"],
                               kernel_size=final_cfg["kernel_size"],
                               padding=final_cfg["kernel_size"] // 2)

    def forward(self, x):
        """
        Decodes a hidden state x.
        Args:
            x : tensor (B, C, W, H), hidden states
        Returns:
            out: tensor (B, C_out, W', H'), upsampled by the transposed convolution
        """
        return self.final(self.deconv(x))


class HRNet(nn.Module):
    """ HRNet, a neural network for multi-frame super resolution (MFSR) by recursive fusion. """

    def __init__(self, config):
        """
        Args:
            config : dict with ``encoder``, ``recursive`` and ``decoder`` entries.
                ``decoder`` must contain either ``pixel_shuffle`` alone or both
                ``deconv`` and ``final``.
        """

        super().__init__()
        self.encode = Encoder(config["encoder"])
        self.fuse = RecuversiveNet(config["recursive"])

        decoder_layers = config["decoder"].keys()
        has_shuffle = "pixel_shuffle" in decoder_layers
        has_deconv = "deconv" in decoder_layers
        has_final = "final" in decoder_layers
        # Exactly one decoder flavour must be fully specified; a lone `deconv`
        # or a lone `final` is rejected here instead of failing later with a KeyError.
        assert has_shuffle != (has_deconv and has_final) and has_deconv == has_final, \
            'Incorrect config for the decoder layer. Specify either `pixel_shuffle` or both `deconv` and `final`'
        decoder = DecoderShuffle if has_shuffle else Decoder
        self.decode = decoder(config["decoder"])

    def forward(self, lrs, alphas):
        """
        Super resolves a batch of low-resolution images.

        Every view is stacked channel-wise with a per-sample reference image
        (the median over the views whose alpha equals 1), encoded, fused
        across views and decoded.

        Args:
            lrs : tensor (B, k, C, H, W), low-resolution images; H and W may differ
            alphas : tensor (B, k), indicator (0 if padded low-res view, 1 otherwise)
        Returns:
            srs: tensor (B, C_out, H', W'), super-resolved images; the upscaling
                is determined by the configured decoder
        """

        batch_size, seq_len, channels, height, width = lrs.shape
        alphas = alphas.view(-1, seq_len, 1, 1, 1)

        # NOTE(review): a sample whose alphas are all 0 leaves no view to take
        # the median of - presumably callers guarantee one valid view; confirm.
        refs = []
        for views, view_alphas in zip(lrs, alphas):
            # reshape (not squeeze) keeps the mask 1-D even when seq_len == 1
            valid = view_alphas.reshape(seq_len) == 1
            ref, _ = torch.median(views[valid], 0, keepdim=True)
            refs.append(ref)
        refs = torch.unsqueeze(torch.cat(refs, 0), 1)  # (B, 1, C, H, W)

        refs = refs.repeat(1, seq_len, 1, 1, 1)

        stacked_input = torch.cat([lrs, refs], 2)  # tensor (B, L, 2*C_in, H, W)

        # Keep the spatial axes in their original (H, W) order: swapping them in
        # a `view` would reinterpret memory and scramble non-square images.
        stacked_input = stacked_input.view(batch_size * seq_len, channels * 2, height, width)
        layer1 = self.encode(stacked_input)  # encode input tensor
        layer1 = layer1.view(batch_size, seq_len, -1, height, width)  # tensor (B, L, C, H, W)

        # fuse, upsample
        recursive_layer = self.fuse(layer1, alphas)  # fuse hidden states (B, C, H, W)
        srs = self.decode(recursive_layer)  # decode final hidden state
        return srs
class ShiftNet(nn.Module):
    ''' ShiftNet, a neural network for sub-pixel registration and interpolation with lanczos kernel. '''

    def __init__(self, in_channel=4, patch_size=128, num_filters=64, size_linear=1024):
        '''
        Args:
            in_channel : int, number of channels of each of the two input images
            patch_size : int, side of the (square) input patches
            num_filters : int, number of filters of the first convolutions
                (doubled from layer5 onwards)
            size_linear : int, width of the hidden fully connected layer
        '''

        super().__init__()

        # Three of the blocks below end with MaxPool2d(2), each halving the side.
        dim_after_conv = patch_size // (2 ** 3)

        self.layer1 = self._conv_block(2 * in_channel, num_filters)
        self.layer2 = self._conv_block(num_filters, num_filters, pool=True)
        self.layer3 = self._conv_block(num_filters, num_filters)
        self.layer4 = self._conv_block(num_filters, num_filters, pool=True)
        self.layer5 = self._conv_block(num_filters, num_filters * 2)
        self.layer6 = self._conv_block(num_filters * 2, num_filters * 2, pool=True)
        self.layer7 = self._conv_block(num_filters * 2, num_filters * 2)
        self.layer8 = self._conv_block(num_filters * 2, num_filters * 2)

        self.drop1 = nn.Dropout(p=0.5)

        self.fc1 = nn.Linear(num_filters * 2 * dim_after_conv * dim_after_conv, size_linear)

        self.activ1 = nn.ReLU()
        self.fc2 = nn.Linear(size_linear, 2, bias=False)
        # Zero weights and no bias: the untrained network predicts a zero shift
        # (the identity transformation).
        self.fc2.weight.data.zero_()

    @staticmethod
    def _conv_block(in_channels, out_channels, pool=False):
        ''' Builds Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU [-> MaxPool2d(2)]. '''
        layers = [nn.Conv2d(in_channels, out_channels, 3, padding=1),
                  nn.BatchNorm2d(out_channels),
                  nn.ReLU()]
        if pool:
            layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)

    def forward(self, x):
        '''
        Registers pairs of images with sub-pixel shifts.

        The first two views are mean-centred per channel on a copy, so the
        caller's tensor is left untouched.

        Args:
            x : tensor (B, 2, C_in, H, W), input pairs of images
        Returns:
            out: tensor (B, 2), translation params
        '''

        batch, nviews, c, h, w = x.shape

        # Centre on a copy instead of writing into the argument in place.
        x = x.clone()
        x[:, :2] = x[:, :2] - x[:, :2].mean(dim=(3, 4), keepdim=True)

        # Stack the views along the channel axis.
        x = x.view(batch, nviews * c, h, w)

        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.layer7(out)
        out = self.layer8(out)

        out = out.flatten(start_dim=1)  # (B, C*W*H)

        out = self.drop1(out)  # dropout on the flattened spatial tensor
        out = self.fc1(out)
        out = self.activ1(out)
        out = self.fc2(out)
        return out

    def transform(self, theta, I, device="cpu"):
        '''
        Shifts images I by theta with Lanczos interpolation.

        Stores ``theta`` on the instance as ``self.theta``.

        Args:
            theta : tensor (B, 2), translation params
            I : tensor (B, C_in, H, W), input images
            device : unused by this method; kept for interface compatibility
        Returns:
            out: the output of ``lanczos.lanczos_shift`` with a new axis
                inserted at position 1
        '''

        self.theta = theta

        new_I = lanczos.lanczos_shift(img=I,
                                      shift=self.theta.flip(-1),  # (dx, dy) from register_batch -> flip
                                      a=3, p=5)[:, None]
        return new_I
def lanczos_kernel(dx, a=3, N=None, dtype=None, device=None):
    '''
    Generates 1D Lanczos kernels for translation and interpolation.
    Args:
        dx : float, tensor (batch_size, 1), the translation in pixels to shift an image.
        a : int, number of lobes in the kernel support.
            If N is None, then the width is the kernel support (length of all lobes),
            S = 2(a + ceil(|dx|)) + 1.
        N : int, width of the kernel.
            If smaller than S then N is set to S.
        dtype : torch dtype of the sample grid (defaults to the dtype of dx)
        device : torch device of the kernel (defaults to the device of dx)
    Returns:
        k: tensor (batch_size, width), one kernel per translation; all kernels
            share the width required by the largest translation
    '''

    if not torch.is_tensor(dx):
        dx = torch.tensor(dx, dtype=dtype, device=device)

    if device is None:
        device = dx.device

    if dtype is None:
        dtype = dx.dtype

    D = dx.abs().ceil().int()
    S = 2 * (a + D) + 1  # width of kernel support

    if (N is None) or (N < S.max()):
        N = S

    Z = (N - S) // 2  # width of zeros beyond kernel support

    # One common sample grid wide enough for every translation in the batch.
    start = int((-(a + D + Z)).min())
    end = int((a + D + Z + 1).max())
    x = torch.arange(start, end, dtype=dtype, device=device).view(1, -1) - dx
    px = (np.pi * x) + 1e-3  # small offset avoids 0/0 at x == 0

    sin_px = torch.sin(px)
    sin_pxa = torch.sin(px / a)

    k = a * sin_px * sin_pxa / px**2  # sinc(x) masked by sinc(x/a)

    return k


def lanczos_shift(img, shift, p=3, a=3):
    '''
    Shifts an image by convolving it with a Lanczos kernel.
    Lanczos interpolation is an approximation to ideal sinc interpolation,
    by windowing a sinc kernel with another sinc function extending up to a
    small number of its lobes (typically a=3).

    Args:
        img : tensor, the images to be shifted. Either
            (batch_size, channels, height, width),
            (batch_size, height, width) - one single-channel image per shift, or
            (height, width) - one image that is shifted by every row of `shift`.
        shift : tensor (batch_size, 2) of translation parameters (dy, dx)
        p : int, reflection padding width prior to convolution (default=3)
        a : int, number of lobes in the Lanczos interpolation kernel (default=3)
    Returns:
        I_s: tensor (batch_size, channels, height, width) of shifted images with
            all singleton dimensions squeezed out
    '''
    dtype = img.dtype
    n_shifts = shift.shape[0]

    # Bring every accepted layout to (batch, channels, height, width).
    if img.dim() == 2:
        img = img[None, None].repeat(n_shifts, 1, 1, 1)  # same image for every shift
    elif img.dim() == 3:  # one image per shift
        assert img.shape[0] == n_shifts
        img = img[:, None]

    # Apply padding
    padder = torch.nn.ReflectionPad2d(p)  # reflect pre-padding
    I_padded = padder(img)

    # Create 1D shifting kernels
    y_shift = shift[:, [0]]
    x_shift = shift[:, [1]]

    k_y = (lanczos_kernel(y_shift, a=a, N=None, dtype=dtype)
           .to(dtype)
           .flip(1)  # flip axis of convolution
           )[:, None, :, None]  # shape (batch, 1, y_kernel, 1)

    k_x = (lanczos_kernel(x_shift, a=a, N=None, dtype=dtype)
           .to(dtype)
           .flip(1)
           )[:, None, None, :]  # shape (batch, 1, 1, x_kernel)

    # Move the batch to the channel axis so that a grouped convolution applies
    # a different kernel to every batch element: (channels, batch, H, W).
    I_s = I_padded.permute(1, 0, 2, 3)

    # The tensors are 4-D, so this is a (separable) 2-D convolution.
    I_s = torch.nn.functional.conv2d(I_s,
                                     weight=k_y,
                                     groups=k_y.shape[0],
                                     padding=(k_y.shape[2] // 2, 0))  # same padding
    I_s = torch.nn.functional.conv2d(I_s,
                                     weight=k_x,
                                     groups=k_x.shape[0],
                                     padding=(0, k_x.shape[3] // 2))

    # Remove the reflection padding; explicit end indices also work for p == 0.
    height, width = I_s.shape[-2:]
    I_s = I_s[..., p:height - p, p:width - p]
    I_s = I_s.permute(1, 0, 2, 3)

    return I_s.squeeze()
def _model_device(model):
    """Return the device of the model's first parameter, or the default device if it has none."""
    parameters = getattr(model, 'parameters', None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_sr(sample, model):
    """
    Super resolves an imset with a given model.

    The inputs are cast to float and moved to the device the model lives on;
    the forward pass runs without gradient tracking.

    Args:
        sample: dict with the low-resolution views under 'lr' - tensor
            (nviews, C, H, W) or (B, nviews, C, H, W) - and the matching
            indicators under 'alphas'
        model: HRNet, pytorch model called as ``model(lrs, alphas)``
    Returns:
        sr: numpy array with the model output, copied to the CPU
    """

    lrs, alphas = sample['lr'], sample['alphas']

    if lrs.ndim == 4:
        # A single imset: add the batch dimension the model expects.
        nviews, c, h, w = lrs.shape
        lrs = lrs.view(1, nviews, c, h, w)
        alphas = alphas.view(1, nviews)

    # Follow the model instead of guessing from CUDA availability, so a model
    # kept on the CPU still works on a machine that has a GPU.
    device = _model_device(model)
    lrs = lrs.float().to(device)
    alphas = alphas.float().to(device)

    # Inference only: do not build an autograd graph for the forward pass.
    with torch.no_grad():
        sr = model(lrs, alphas)

    return sr.detach().cpu().numpy()


def load_model(config, checkpoint_file):
    """
    Loads a pretrained model from disk.
    Args:
        config: dict, configuration file; ``config["network"]`` is passed to HRNet
        checkpoint_file: str, checkpoint filename
    Returns:
        model: HRNet, a pytorch model on CUDA when available, otherwise on the CPU
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = HRNet(config["network"]).to(device)
    # NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(checkpoint_file, map_location=device))
    return model


class Model(object):
    """Callable wrapper that pairs a configuration with a lazily loaded HRNet."""

    def __init__(self, config):
        """
        Args:
            config: dict, configuration file
        """
        self.config = config
        self.model = None  # set by load_checkpoint

    def load_checkpoint(self, checkpoint_file):
        """Build the network from the stored config and load the given checkpoint."""
        self.model = load_model(self.config, checkpoint_file)

    def __call__(self, sample):
        """Super resolve ``sample`` with the loaded model (see ``get_sr``)."""
        return get_sr(sample, self.model)
def distributions_plot(s2_values: np.ndarray, deimos_values: np.ndarray, sr_values: np.ndarray, band: int,
                       bins: int = 100, value_range=(-2, 2)) -> np.ndarray:
    """
    Return plot image histogram of S2, Deimos and SR distributions for a particular band.

    s2_values: np.ndarray: NxCxWxH array of S2 values
    deimos_values: np.ndarray: NxCxWxH array of deimos values
    sr_values: np.ndarray: NxCxWxH array of predicted super resolved images
    band: int: index of the band (axis 1) to plot
    bins: int: number of histogram bins (default 100)
    value_range: tuple: (lower, upper) range of the histograms (default (-2, 2))

    return: np.ndarray: matplotlib plot of the distributions converted to numpy array image (so it can be passed to WANDB).
    """
    # Make a Figure and attach it to an Agg canvas (no pyplot state involved).
    fig = Figure(figsize=(15, 7), dpi=300)
    canvas = FigureCanvasAgg(fig)

    ax = fig.add_subplot(1, 1, 1)
    shared = dict(alpha=.33, bins=bins, range=value_range, density=True)
    ax.hist(s2_values[:, band, ...].flatten(), label='S2', **shared)
    ax.hist(deimos_values[:, band, ...].flatten(), label='Deimos', **shared)
    # The prediction is drawn as an outline so it stays visible on top of the others.
    ax.hist(sr_values[:, band, ...].flatten(), label='SR', histtype='step', **shared)
    ax.set_title(f'Band {band}')
    ax.legend()

    # Render, then expose the RGBA renderer buffer as a NumPy array.
    canvas.draw()
    buf = canvas.buffer_rgba()
    return np.asarray(buf)


def normalize_plotting(rgb, min_value=-1.0, max_value=1.0):
    """
    Rescales the data linearly so that [min_value, max_value] maps to [0, 1].

    Values outside the given range are not clipped.

    rgb: np.ndarray: CxHxW array of bands.
    min_value: float or sequence of C floats: value(s) mapped to 0 (default -1).
    max_value: float or sequence of C floats: value(s) mapped to 1 (default 1).

    return: np.ndarray: HxWxC float64 array (channels moved to the last axis).
    """
    n_channels = len(rgb)
    rgb = np.moveaxis(rgb, 0, 2)
    # Per-channel float64 bounds; scalars are broadcast to every channel.
    min_rgb = np.broadcast_to(np.asarray(min_value, dtype=np.float64), (n_channels,))
    max_rgb = np.broadcast_to(np.asarray(max_value, dtype=np.float64), (n_channels,))
    span = max_rgb - min_rgb
    if np.any(span == 0):
        raise ValueError('`min_value` and `max_value` must differ')
    return (rgb - min_rgb) / span
], [ 32.959046774232398, 35.105605394260266 ], [ 32.971658772078321, 35.114564996016895 ], [ 32.980989754313427, 35.117725951584234 ], [ 32.989461802966936, 35.129197997174728 ], [ 33.014349861111775, 35.153218774629401 ], [ 33.016869989106574, 35.155206342020549 ], [ 33.03401450969379, 35.162349079243121 ], [ 33.042942616770397, 35.165720048058489 ], [ 33.063536993526505, 35.171401442016624 ], [ 33.081066296251898, 35.172040140783778 ], [ 33.097580224624302, 35.169466675553693 ], [ 33.124933140695411, 35.170171863066017 ], [ 33.134270731234793, 35.170334313148913 ], [ 33.135057712766738, 35.170319721210944 ], [ 33.13408361433774, 35.049465375084011 ] ] ] ] } }, 7 | { "type": "Feature", "properties": { "OBJECTID": 6427, "ID": 0, "PARCELID": "Famia19", "APPLID": null, "CROPM": null, "AREAM": 0.0, "PERIMETERM": 0.0, "TOLCOEF": 0.0, "AREATOL": 1000.0, "MEASDATE": null, "INSPECTOR": null, "DATE_TAKEN": "2019\/04\/23", "SHAPE_AREA": 272795008.50099999, "SHAPE_LEN": 81492.862667399997, "area": 156 }, "geometry": { "type": "MultiPolygon", "coordinates": [ [ [ [ 33.926034338799496, 35.069101158815542 ], [ 33.927168262119764, 34.974517068477383 ], [ 33.904310674014198, 34.969562139676036 ], [ 33.893822415472989, 34.953576360206682 ], [ 33.874802964951883, 34.94719028871171 ], [ 33.857104321599223, 34.949543202053171 ], [ 33.851431199694872, 34.952008656723763 ], [ 33.803787281556303, 34.974163335941178 ], [ 33.803556742527967, 35.000010900353701 ], [ 33.804326161835242, 35.064810231998024 ], [ 33.82019029007548, 35.073619808806463 ], [ 33.825836764160186, 35.072388615043486 ], [ 33.839973806345903, 35.064339027864627 ], [ 33.845240671101521, 35.063904813931629 ], [ 33.850816350047424, 35.065879496374173 ], [ 33.856211724523448, 35.067883632941928 ], [ 33.871427777750675, 35.083100329148962 ], [ 33.867144026844976, 35.100092944649404 ], [ 33.87243854998033, 35.124653925907026 ], [ 33.879125264987223, 35.125422812378631 ], [ 33.88696876938608, 35.124018144904262 ], [ 
33.902057244995177, 35.104590705720746 ], [ 33.919097248801954, 35.08731431911599 ], [ 33.906815544005077, 35.073347479404532 ], [ 33.9159993498305, 35.068295564738953 ], [ 33.926034338799496, 35.069101158815542 ] ] ] ] } }, 8 | { "type": "Feature", "properties": { "OBJECTID": 6425, "ID": 0, "PARCELID": "Pefia19", "APPLID": null, "CROPM": null, "AREAM": 0.0, "PERIMETERM": 0.0, "TOLCOEF": 0.0, "AREATOL": 1000.0, "MEASDATE": null, "INSPECTOR": null, "DATE_TAKEN": "2019\/05\/12", "SHAPE_AREA": 200176701.97099999, "SHAPE_LEN": 58967.8934973, "area": 117 }, "geometry": { "type": "MultiPolygon", "coordinates": [ [ [ [ 32.570879367763091, 34.885940183049399 ], [ 32.441188841042631, 34.885413220487592 ], [ 32.440585394286423, 34.974357059492164 ], [ 32.570857331615741, 34.974344086644429 ], [ 32.570872972208875, 34.91204193984705 ], [ 32.570879367763091, 34.885940183049399 ] ] ] ] } } 9 | ] 10 | } 11 | -------------------------------------------------------------------------------- /input/TestAreaLithuania.geojson: -------------------------------------------------------------------------------- 1 | { 2 | "type": "FeatureCollection", 3 | "name": "TestAreaLithuania", 4 | "crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } }, 5 | "features": [ 6 | { "type": "Feature", "properties": { "OBJECTID": 1, "SHAPE_Leng": 110262.89625200001, "SHAPE_Area": 687673950.51800001, "area_km": 2600 }, "geometry": { "type": "MultiPolygon", "coordinates": [ [ [ [ 21.613652708531841, 55.631815908542102 ], [ 23.230171728490621, 55.655015941583898 ], [ 23.235248332025229, 55.422223710388444 ], [ 21.628685662804344, 55.4058565378606 ], [ 21.613652708531841, 55.631815908542102 ] ] ] ] } } 7 | ] 8 | } 9 | -------------------------------------------------------------------------------- /input/config-local-hrn-deconv.json: -------------------------------------------------------------------------------- 1 | { 2 | "paths": { 3 | "prefix": 
"/home/ubuntu/npz-unpacked-small-2.5m/", 4 | "checkpoint_dir": "models/weights", 5 | "scores_dir": "models/scores", 6 | "tb_log_file_dir": "tb_logs/" 7 | }, 8 | 9 | "network": { 10 | "upscale_factor": 4, 11 | "encoder": { 12 | "in_channels": 8, 13 | "num_layers" : 8, 14 | "kernel_size": 3, 15 | "channel_size": 64 16 | }, 17 | "recursive": { 18 | "alpha_residual": true, 19 | "in_channels": 64, 20 | "kernel_size": 3 21 | }, 22 | "decoder": { 23 | "final": { 24 | "in_channels": 64, 25 | "kernel_size": 1, 26 | "out_channels": 4 27 | }, 28 | "deconv": { 29 | "stride": 4, 30 | "in_channels": 64, 31 | "kernel_size": 3, 32 | "out_channels": 64 33 | } 34 | } 35 | }, 36 | 37 | "training": { 38 | "num_epochs": 50, 39 | "validation_metrics": ["SSIM", "PSNR", "MSE", "MIXED", "MAE"], 40 | "histogram_matching": false, 41 | "loss_metric": "MIXED", 42 | "apply_correction": true, 43 | "augment": true, 44 | "use_reg_regularization": true, 45 | "lambda": 0.001, 46 | "use_kl_div_loss": true, 47 | "eta": 10, 48 | "wandb": true, 49 | "use_gpu": true, 50 | "batch_size": 128, 51 | "n_views": 8, 52 | "n_workers": 8, 53 | "reg_offset": 16, 54 | "lr": 0.0007, 55 | "patch_size": 32, 56 | "seed": 0, 57 | "channels_features": [0, 1, 2, 3], 58 | "channels_labels": [0, 1, 2, 3] 59 | }, 60 | "visualization": { 61 | "channels_to_plot": [2, 1, 0], 62 | "distribution_sampling_proba": 0.05 63 | 64 | }, 65 | "perceptual_loss": { 66 | "model_name": "NivaModelV2", 67 | "weight": 0.05 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /input/config-local-hrn-pix-shu.json: -------------------------------------------------------------------------------- 1 | { 2 | "paths": { 3 | "prefix": "/home/ubuntu/npz-unpacked-small-2.5m/", 4 | "checkpoint_dir": "models/weights", 5 | "scores_dir": "models/scores", 6 | "tb_log_file_dir": "tb_logs/" 7 | }, 8 | 9 | "network": { 10 | "upscale_factor": 4, 11 | "encoder": { 12 | "in_channels": 8, 13 | "num_layers" : 8, 14 | 
"kernel_size": 3, 15 | "channel_size": 64 16 | }, 17 | "recursive": { 18 | "alpha_residual": true, 19 | "in_channels": 64, 20 | "kernel_size": 3 21 | }, 22 | "decoder": { 23 | "pixel_shuffle": { 24 | "in_channels": 64, 25 | "kernel_size": 3, 26 | "out_channels": 64, 27 | "stride": 1, 28 | "scale": 4 29 | } 30 | } 31 | }, 32 | 33 | "training": { 34 | "num_epochs": 50, 35 | "validation_metrics": ["SSIM", "PSNR", "MSE", "MIXED", "MAE"], 36 | "histogram_matching": false, 37 | "loss_metric": "MSE", 38 | "apply_correction": true, 39 | "augment": false, 40 | "use_reg_regularization": true, 41 | "lambda": 0.001, 42 | "use_kl_div_loss": true, 43 | "eta": 10, 44 | "wandb": true, 45 | "use_gpu": true, 46 | "batch_size": 32, 47 | "n_views": 8, 48 | "n_workers": 8, 49 | "reg_offset": 16, 50 | "lr": 0.0007, 51 | "patch_size": 32, 52 | "seed": 0, 53 | "channels_features": [0, 1, 2, 3], 54 | "channels_labels": [0, 1, 2, 3] 55 | }, 56 | "visualization": { 57 | "channels_to_plot": [2, 1, 0], 58 | "distribution_sampling_proba": 0.05 59 | 60 | }, 61 | "perceptual_loss": { 62 | "model_name": "NivaModelV2", 63 | "weight": 0.05 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /notebooks/00-parse-deimos-metadata.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Read DEIMOS metadata from XML file\n", 8 | "\n", 9 | "Extract relevant information about the DEIMOS bands into dataframes.\n", 10 | "\n", 11 | "**NOTE**: DEIMOS bands are provided as `NIR-R-G-B`, while we store them in `EOPatches` as `B-G-R-NIR` as in Sentinel-2 datasets. This means that we will have to swap the info read from XML files in `split_per_band`." 
12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": { 18 | "lines_to_end_of_cell_marker": 2 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "import os\n", 23 | "from xml.etree import ElementTree as ET\n", 24 | "\n", 25 | "import pandas as pd\n", 26 | "from fs_s3fs import S3FS" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Config" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "instance_id = ''\n", 43 | "aws_access_key_id = ''\n", 44 | "aws_secret_access_key = ''" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "filesystem = S3FS(bucket_name='',\n", 54 | " aws_access_key_id=aws_access_key_id,\n", 55 | " aws_secret_access_key=aws_secret_access_key)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "tiles_folder = ''" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "def tag_parser(el_iterator, vals_dict, attr='text', attrib_key=None):\n", 74 | " for sub in el_iterator:\n", 75 | " if attr == 'attrib':\n", 76 | " vals_dict[sub.tag] = getattr(sub, attr)[attrib_key]\n", 77 | " else:\n", 78 | " vals_dict[sub.tag] = getattr(sub, attr)\n", 79 | "\n", 80 | "\n", 81 | "def multitag_parser(el_iterator, vals_dict, attr='text'):\n", 82 | " children = []\n", 83 | " for sub in el_iterator:\n", 84 | " tag_name = sub.tag\n", 85 | " x = {}\n", 86 | " tag_parser(sub.getchildren(), x)\n", 87 | " children.append(x)\n", 88 | " vals_dict[tag_name] = children\n", 89 | "\n", 90 | "\n", 91 | "def parse_bbox(el_iterator, vals_dict, outname, use_xy=True):\n", 92 | " appendices = ['X', 'Y'] if use_xy else ['LAT', 'LON']\n", 93 | " 
vertex_dict = {f'FRAME_{appendix}': [] for appendix in appendices}\n", 94 | " for vertex in el_iterator:\n", 95 | " for appendix in appendices:\n", 96 | " vertex_dict[f'FRAME_{appendix}'].append(vertex.find(f'./FRAME_{appendix}').text)\n", 97 | "\n", 98 | " if use_xy:\n", 99 | " vals_dict[outname] = [min(vertex_dict['FRAME_X']), min(vertex_dict['FRAME_Y']),\n", 100 | " max(vertex_dict['FRAME_X']), max(vertex_dict['FRAME_Y'])]\n", 101 | " else:\n", 102 | " vals_dict[outname] = [min(vertex_dict['FRAME_LAT']), min(vertex_dict['FRAME_LON']),\n", 103 | " max(vertex_dict['FRAME_LAT']), max(vertex_dict['FRAME_LON'])]\n", 104 | "\n", 105 | "\n", 106 | "def split_per_band(columns, column, query_keys, revert_bands=True,\n", 107 | " index_col='BAND_INDEX', n_bands=4):\n", 108 | "\n", 109 | " for valdict in columns[column]:\n", 110 | " if all([key in set(valdict.keys()) for key in query_keys]):\n", 111 | " for key in query_keys:\n", 112 | " idx = int(valdict[index_col])\n", 113 | " if revert_bands:\n", 114 | " idx = n_bands-idx+1\n", 115 | " columns[f'{key}_{idx}'] = valdict[key]\n", 116 | " columns.pop(column, None)\n", 117 | "\n", 118 | "\n", 119 | "def parse_deimos_metadata_file(metadata_file, filesystem):\n", 120 | " tree = ET.parse(filesystem.open(metadata_file))\n", 121 | " root = tree.getroot()\n", 122 | " columns = {}\n", 123 | " tag_parser(root.findall('./Dataset_Id/'), columns)\n", 124 | " tag_parser(root.findall('./Production/'), columns)\n", 125 | " tag_parser(root.findall('./Data_Processing/'), columns)\n", 126 | " tag_parser(root.findall('./Raster_CS/'), columns)\n", 127 | " parse_bbox(root.findall('./Dataset_Frame/'), columns, 'bbox')\n", 128 | " tag_parser(root.findall('./Raster_Encoding/'), columns)\n", 129 | " tag_parser(root.findall('./Data_Access/'), columns)\n", 130 | " tag_parser(root.findall('./Data_Access/Data_File/'), columns, attr='attrib', attrib_key='href')\n", 131 | " tag_parser(root.findall('./Raster_Dimensions/'), columns)\n", 132 | " 
multitag_parser(root.findall('./Image_Interpretation/'), columns)\n", 133 | " multitag_parser(root.findall('./Image_Display/'), columns)\n", 134 | " tag_parser(root.findall('./Dataset_Sources/Source_Information/Coordinate_Reference_System/'), columns)\n", 135 | " tag_parser(root.findall('./Dataset_Sources/Source_Information/Scene_Source/'), columns)\n", 136 | " multitag_parser(root.findall('./Dataset_Sources/Source_Information/Quality_Assessment/'), columns)\n", 137 | " parse_bbox(root.findall('./Dataset_Sources/Source_Information/Source_Frame/'),\n", 138 | " columns, 'source_frame_bbox_latlon', use_xy=False)\n", 139 | "\n", 140 | " split_per_band(columns,\n", 141 | " 'Band_Statistics',\n", 142 | " ['STX_STDV', 'STX_MEAN', 'STX_MIN', 'STX_MAX'])\n", 143 | " split_per_band(columns,\n", 144 | " 'Spectral_Band_Info',\n", 145 | " ['PHYSICAL_GAIN', 'PHYSICAL_BIAS', 'PHYSICAL_UNIT', 'ESUN'])\n", 146 | " return pd.DataFrame([columns])" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "ms4_dfs = []\n", 156 | "pan_dfs = []\n", 157 | "\n", 158 | "tiles = filesystem.listdir(tiles_folder)\n", 159 | "\n", 160 | "for tile in tiles:\n", 161 | " # this is needed because folder was copied from somewhere else\n", 162 | " if not filesystem.exists(f'{tiles_folder}/{tile}'):\n", 163 | " filesystem.makedirs(f'{tiles_folder}/{tile}')\n", 164 | "\n", 165 | " metadata = filesystem.listdir(f'{tiles_folder}/{tile}')\n", 166 | " metadata = [meta for meta in metadata if os.path.splitext(meta)[-1] == '.dim']\n", 167 | "\n", 168 | " metadata_file_ms4 = metadata[0] if '_MS4_' in metadata[0] else metadata[1]\n", 169 | " metadata_file_pan = metadata[0] if '_PAN_' in metadata[0] else metadata[1]\n", 170 | "\n", 171 | " ms4_dfs.append(parse_deimos_metadata_file(f'{tiles_folder}/{tile}/{metadata_file_ms4}',\n", 172 | " filesystem))\n", 173 | " 
pan_dfs.append(parse_deimos_metadata_file(f'{tiles_folder}/{tile}/{metadata_file_pan}',\n", 174 | " filesystem))\n", 175 | "\n", 176 | "ms4_metadata = pd.concat(ms4_dfs)\n", 177 | "pan_metadata = pd.concat(pan_dfs)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "ms4_metadata.to_parquet(filesystem.openbin('metadata/deimos_ms4_metadata.pq', 'wb'))\n", 187 | "pan_metadata.to_parquet(filesystem.openbin('metadata/deimos_pan_metadata.pq', 'wb'))" 188 | ] 189 | } 190 | ], 191 | "metadata": { 192 | "kernelspec": { 193 | "display_name": "Python 3", 194 | "language": "python", 195 | "name": "python3" 196 | } 197 | }, 198 | "nbformat": 4, 199 | "nbformat_minor": 2 200 | } -------------------------------------------------------------------------------- /notebooks/00a-add-per-tile-median.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from datetime import datetime" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "lines_to_next_cell": 2 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import numpy as np\n", 21 | "import pandas as pd\n", 22 | "from eolearn.core import FeatureType\n", 23 | "from eolearn.io import ImportFromTiff\n", 24 | "from fs_s3fs import S3FS\n", 25 | "from sentinelhub import SHConfig\n", 26 | "from tqdm.auto import tqdm" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "AWS_ACCESS_KEY_ID = ''\n", 36 | "AWS_SECRET_ACCESS_KEY = ''\n", 37 | "BUCKET_NAME = ''\n", 38 | "LOC_ON_BUCKET = ''" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "filesystem 
= S3FS(bucket_name=BUCKET_NAME,\n", 48 | " aws_access_key_id=AWS_ACCESS_KEY_ID,\n", 49 | " aws_secret_access_key=AWS_SECRET_ACCESS_KEY)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "config = SHConfig()\n", 59 | "config.aws_access_key_id = AWS_ACCESS_KEY_ID\n", 60 | "config.aws_secret_access_key = AWS_SECRET_ACCESS_KEY" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "metadata_ms4 = pd.read_parquet(filesystem.openbin('metadata/deimos_ms4_metadata.pq'))\n", 70 | "metadata_pan = pd.read_parquet(filesystem.openbin('metadata/deimos_pan_metadata.pq'))" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "medians_ms4 = []\n", 80 | "medians_pan = []\n", 81 | "for i, sensing_time in enumerate(tqdm(filesystem.listdir(LOC_ON_BUCKET))):\n", 82 | " try:\n", 83 | " folder = f's3://{BUCKET_NAME}/{LOC_ON_BUCKET}/{sensing_time}'\n", 84 | " eop_ms4 = ImportFromTiff((FeatureType.DATA, 'MS4'), folder=folder, config=config).execute(\n", 85 | " filename=['B04.tiff', 'B03.tiff', 'B02.tiff', 'B01.tiff'])\n", 86 | " eop = ImportFromTiff((FeatureType.DATA, 'PAN'), folder=folder,\n", 87 | " config=config).execute(eop_ms4, filename='PAN.tiff')\n", 88 | " eop.timestamp = [datetime.strptime(sensing_time, '%Y-%m-%d_%H-%M-%S')]\n", 89 | "\n", 90 | " data = eop.data['MS4']\n", 91 | " mask = data[..., 0] > 0\n", 92 | " data_masked = data[mask, :]\n", 93 | " median_ms4 = np.median(data_masked, axis=0)\n", 94 | "\n", 95 | " medians_ms4.append({'sensing_time': eop.timestamp,\n", 96 | " 'STX_MEDIAN_1': median_ms4[0], 'STX_MEDIAN_2': median_ms4[1],\n", 97 | " 'STX_MEDIAN_3': median_ms4[2], 'STX_MEDIAN_4': median_ms4[3]})\n", 98 | "\n", 99 | " data = eop.data['PAN']\n", 100 | " mask = data[..., 0] > 0\n", 101 | " data_masked = data[mask, :]\n", 102
| " median_pan = np.median(data_masked, axis=0)\n", 103 | " medians_pan.append({'sensing_time': eop.timestamp,\n", 104 | " 'STX_MEDIAN_1': median_pan[0]})\n", 105 | "\n", 106 | " except Exception as e:\n", 107 | " print(e)\n", 108 | " print(f'Failed to proces sensing time: {sensing_time}')" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "ms4 = pd.DataFrame(medians_ms4)\n", 118 | "ms4.sensing_time = ms4.sensing_time.apply(lambda x: x[0])\n", 119 | "ms4.sensing_time = ms4.sensing_time.apply(lambda x: str(x).replace(' ', 'T'))\n", 120 | "metadata_ms4_median = metadata_ms4.set_index('START_TIME').join(ms4.set_index('sensing_time')).reset_index()\n", 121 | "\n", 122 | "with filesystem.openbin('metadata/deimos_ms4_metadata.pq', 'wb') as f:\n", 123 | " metadata_ms4_median.to_parquet(f)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "pan = pd.DataFrame(medians_pan)\n", 133 | "pan.sensing_time = pan.sensing_time.apply(lambda x: x[0])\n", 134 | "pan.sensing_time = pan.sensing_time.apply(lambda x: str(x).replace(' ', 'T'))\n", 135 | "pan_median = metadata_pan.set_index('START_TIME').join(pan.set_index('sensing_time')).reset_index()\n", 136 | "\n", 137 | "with filesystem.openbin('metadata/deimos_pan_metadata.pq', 'wb') as f:\n", 138 | " pan_median.to_parquet(f)" 139 | ] 140 | } 141 | ], 142 | "metadata": { 143 | "kernelspec": { 144 | "display_name": "Environment (conda_tensorflow2_p36)", 145 | "language": "python", 146 | "name": "conda_tensorflow2_p36" 147 | } 148 | }, 149 | "nbformat": 4, 150 | "nbformat_minor": 2 151 | } -------------------------------------------------------------------------------- /notebooks/00b-calculate-cloudfree-deimos-stats.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | 
"cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "from datetime import datetime" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "lines_to_next_cell": 2 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "from eolearn.core import FeatureType\n", 24 | "from eolearn.io import ImportFromTiff\n", 25 | "from fs_s3fs import S3FS\n", 26 | "from sentinelhub import SHConfig\n", 27 | "from tqdm.auto import tqdm" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "AWS_ACCESS_KEY_ID = ''\n", 37 | "AWS_SECRET_ACCESS_KEY = ''\n", 38 | "BUCKET_NAME = ''\n", 39 | "LOC_ON_BUCKET = ''" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "filesystem = S3FS(bucket_name=BUCKET_NAME,\n", 49 | " aws_access_key_id=AWS_ACCESS_KEY_ID,\n", 50 | " aws_secret_access_key=AWS_SECRET_ACCESS_KEY)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "config = SHConfig()\n", 60 | "config.aws_access_key_id = AWS_ACCESS_KEY_ID\n", 61 | "config.aws_secret_access_key = AWS_SECRET_ACCESS_KEY" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "BAND_GAIN = {3: 0.006800104616, 2: 0.011123248049, 1: 0.013184818227, 0: 0.014307912429}\n", 71 | "BAND_BIAS = {3: -0.00680010461, 2: -0.01112324804, 1: -0.01318481822, 0: -0.01430791242}\n", 72 | "PAN_GAIN = 0.011354020831\n", 73 | "PAN_BIAS = -0.01135402083" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "MS4_THRESHOLD = 100\n", 83 
| "PAN_THRESHOLD = 100" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "metadata_ms4 = pd.read_parquet(filesystem.openbin('metadata/deimos_ms4_metadata.pq'))\n", 93 | "metadata_pan = pd.read_parquet(filesystem.openbin('metadata/deimos_pan_metadata.pq'))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "metadata_ms4.columns" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "CLM_MASK_BAND = 0 # Blue" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "stats_ms4 = []\n", 121 | "stats_pan = []\n", 122 | "\n", 123 | "\n", 124 | "def calculate_stats(data, sensing_time):\n", 125 | " median = np.median(data, axis=0)\n", 126 | " mean = np.mean(data, axis=0)\n", 127 | " std = np.std(data, axis=0)\n", 128 | "\n", 129 | " stats = {'sensing_time': sensing_time}\n", 130 | " for i, (bmedian, bstd, bmean) in enumerate(zip(median, std, mean)):\n", 131 | " band_stats = {f'STX_CLM_MEDIAN_{i+1}': bmedian,\n", 132 | " f'STX_CLM_STDV_{i+1}': bstd,\n", 133 | " f'STX_CLM_MEAN_{i+1}': bmean}\n", 134 | "\n", 135 | " stats = {**stats, **band_stats}\n", 136 | "\n", 137 | " return stats\n", 138 | "\n", 139 | "\n", 140 | "def calculate_stats_radiance(data, sensing_time):\n", 141 | " _, chnls = data.shape\n", 142 | " if chnls == 1:\n", 143 | " data = data*PAN_GAIN + PAN_BIAS\n", 144 | " elif chnls == 4:\n", 145 | " data = np.add(np.multiply(data, list(BAND_GAIN.values())), list(BAND_BIAS.values()))\n", 146 | " else:\n", 147 | " raise ValueError(\"Wrong number of channels.\")\n", 148 | "\n", 149 | " median = np.median(data, axis=0)\n", 150 | " mean = np.mean(data, axis=0)\n", 151 | " std = np.std(data, axis=0)\n", 
152 | "\n", 153 | " stats = {'sensing_time': sensing_time}\n", 154 | " for i, (bmedian, bstd, bmean) in enumerate(zip(median, std, mean)):\n", 155 | " band_stats = {f'STX_CLM_RADIANCE_MEDIAN_{i+1}': bmedian,\n", 156 | " f'STX_CLM_RADIANCE_STDV_{i+1}': bstd,\n", 157 | " f'STX_CLM_RADIANCE_MEAN_{i+1}': bmean}\n", 158 | " stats = {**stats, **band_stats}\n", 159 | "\n", 160 | " return stats\n", 161 | "\n", 162 | "\n", 163 | "def calculate_cloudfree_stats(tile_folder, config, clm_mask_band, band_gain, band_bias, ms4_thr, pan_gain, pan_bias, pan_thr, calculate_stats_func):\n", 164 | " try:\n", 165 | " eop_ms4 = ImportFromTiff((FeatureType.DATA, 'MS4'), folder=tile_folder, config=config).execute(\n", 166 | " filename=['B04.tiff', 'B03.tiff', 'B02.tiff', 'B01.tiff'])\n", 167 | " eop = ImportFromTiff((FeatureType.DATA, 'PAN'), folder=tile_folder,\n", 168 | " config=config).execute(eop_ms4, filename='PAN.tiff')\n", 169 | " eop.timestamp = [datetime.strptime(sensing_time, '%Y-%m-%d_%H-%M-%S')]\n", 170 | " data = eop.data['MS4']\n", 171 | "\n", 172 | " mask = (data[..., clm_mask_band]*band_gain[clm_mask_band] + band_bias[clm_mask_band]) > MS4_THRESHOLD\n", 173 | " mask = mask.astype(np.float32)\n", 174 | " mask[data[..., 0] == 0] = np.nan\n", 175 | " coverage = mask[mask == 1].sum() / np.count_nonzero(~np.isnan(mask))\n", 176 | "\n", 177 | " data_masked = data[mask == 0, :]\n", 178 | " # TODO: Why is this here... 
Serves me right for not commenting.\n", 179 | " if coverage > 0.1:\n", 180 | " stats_ms4 = calculate_stats_func(data_masked, eop.timestamp[0])\n", 181 | " else:\n", 182 | " stats_ms4 = calculate_stats_func(data[data[..., 0] > 0, :], eop.timestamp[0])\n", 183 | "\n", 184 | " data = eop.data['PAN'].squeeze()\n", 185 | " mask = ((eop.data['PAN']*PAN_GAIN + PAN_BIAS) > PAN_THRESHOLD).squeeze()\n", 186 | " data_masked = data[mask]\n", 187 | " mask = mask.astype(np.float32)\n", 188 | " mask[data == 0] = np.nan\n", 189 | " data_masked = data[mask == 0]\n", 190 | "\n", 191 | " if coverage > 0.1:\n", 192 | " stats_pan = calculate_stats_func(np.expand_dims(data_masked, -1), eop.timestamp[0])\n", 193 | " else:\n", 194 | " stats_pan = calculate_stats_func(np.expand_dims(data[data > 0], -1), eop.timestamp[0])\n", 195 | " return stats_ms4, stats_pan\n", 196 | "\n", 197 | " except Exception as e:\n", 198 | " print(f'Failed for sensing time {sensing_time} with error: {e}')\n", 199 | " return None, None\n", 200 | "\n", 201 | "\n", 202 | "results = []\n", 203 | "for sensing_time in tqdm(filesystem.listdir(LOC_ON_BUCKET)):\n", 204 | " results.append(calculate_cloudfree_stats(tile_folder=os.path.join('s3://', BUCKET_NAME, LOC_ON_BUCKET, sensing_time),\n", 205 | " config=config,\n", 206 | " clm_mask_band=CLM_MASK_BAND,\n", 207 | " band_gain=BAND_GAIN,\n", 208 | " band_bias=BAND_BIAS,\n", 209 | " ms4_thr=MS4_THRESHOLD,\n", 210 | " pan_gain=PAN_GAIN,\n", 211 | " pan_bias=PAN_BIAS,\n", 212 | " pan_thr=PAN_THRESHOLD,\n", 213 | " calculate_stats_func=calculate_stats\n", 214 | " ))" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "stats_ms4, stats_pan = list(zip(*results))" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "stats_ms4" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | 
"execution_count": null, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "ms4 = pd.DataFrame([x for x in stats_ms4 if x is not None])\n", 242 | "ms4.sensing_time = ms4.sensing_time.apply(lambda x: str(x).replace(' ', 'T'))\n", 243 | "metadata_ms4_stats = metadata_ms4.set_index('START_TIME').join(ms4.set_index('sensing_time')).reset_index()\n", 244 | "with filesystem.openbin('metadata/deimos_ms4_metadata.pq', 'wb') as f:\n", 245 | " metadata_ms4_stats.to_parquet(f)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "pan = pd.DataFrame([x for x in stats_pan if x is not None])\n", 255 | "pan.sensing_time = pan.sensing_time.apply(lambda x: str(x).replace(' ', 'T'))\n", 256 | "pan_stats = metadata_pan.set_index('START_TIME').join(pan.set_index('sensing_time')).reset_index()\n", 257 | "with filesystem.openbin('metadata/deimos_pan_metadata.pq', 'wb') as f:\n", 258 | " pan_stats.to_parquet(f)" 259 | ] 260 | } 261 | ], 262 | "metadata": { 263 | "kernelspec": { 264 | "display_name": "Environment (conda_tensorflow2_p36)", 265 | "language": "python", 266 | "name": "conda_tensorflow2_p36" 267 | } 268 | }, 269 | "nbformat": 4, 270 | "nbformat_minor": 2 271 | } -------------------------------------------------------------------------------- /notebooks/01-download-to-eopatches.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "lines_to_end_of_cell_marker": 2 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "from datetime import timedelta\n", 21 | "\n", 22 | "import geopandas as gpd\n", 23 | "import numpy as np\n", 24 | "import pandas as pd\n", 25 | "from eolearn.core 
import (\n", 26 | " EOExecutor,\n", 27 | " EOPatch,\n", 28 | " EOTask,\n", 29 | " EOWorkflow,\n", 30 | " FeatureType,\n", 31 | " OverwritePermission,\n", 32 | " SaveTask,\n", 33 | ")\n", 34 | "from eolearn.io import SentinelHubInputTask\n", 35 | "from fs_s3fs import S3FS\n", 36 | "from matplotlib import pyplot as plt\n", 37 | "from sentinelhub import DataCollection, SHConfig, UtmZoneSplitter" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "# Config" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "config = SHConfig()\n", 54 | "config.sh_client_id = ''\n", 55 | "config.sh_client_secret = ''\n", 56 | "config.instance_id = ''\n", 57 | "config.aws_access_key_id = ''\n", 58 | "config.aws_secret_access_key = ''" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "test_area_cyprus = gpd.read_file('../input/TestAreaCyprus.geojson')\n", 68 | "test_area_lithuania = gpd.read_file('../input/TestAreaLithuania.geojson')" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "filesystem = S3FS(bucket_name='',\n", 78 | " aws_access_key_id=config.aws_access_key_id,\n", 79 | " aws_secret_access_key=config.aws_secret_access_key)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "# Load metadata" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "metadata_ms4 = pd.read_parquet(filesystem.openbin('metadata/deimos_ms4_metadata.pq'))\n", 96 | "metadata_pan = pd.read_parquet(filesystem.openbin('metadata/deimos_pan_metadata.pq'))" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "# Split into 
small bboxes" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "bbox_splitter = UtmZoneSplitter(test_area_lithuania.geometry.to_list(\n", 113 | ") + test_area_cyprus.geometry.to_list(), crs=test_area_lithuania.crs, bbox_size=2000)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "bboxes = bbox_splitter.get_bbox_list(buffer=0.2)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "# Download data from SH service" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "Downloads data:\n", 137 | "\n", 138 | "1. Sentinel-2 data in interval 2020-04-01 - 2020-10-31\n", 139 | "2. Deimos MS4 bands\n", 140 | "3. Deimos PAN band\n", 141 | "4. Deimos pansharpened\n", 142 | "\n", 143 | "Downloads everything as digital numbers. The Deimos data is saved to one eopatch, S-2 data to another. This is due to the fact that the acquisition timestamps are different for these two data sources. Metadata is also added to S-2 bands that can be used for normalizing." 
144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "evalscript_pansharp = '''\n", 153 | "//VERSION=3\n", 154 | "\n", 155 | "function setup() {\n", 156 | " return {\n", 157 | " input: [{\n", 158 | " bands: [\"B01\", \"B02\", \"B03\", \"B04\", \"PAN\", \"dataMask\"],\n", 159 | " units: [\"DN\", \"DN\", \"DN\", \"DN\", \"DN\", \"DN\"]\n", 160 | " }],\n", 161 | " output: [\n", 162 | " { id:\"bands\", bands:4, sampleType: SampleType.UINT16 }, \n", 163 | " { id:\"bool_mask\", bands:1, sampleType: SampleType.UINT8 },\n", 164 | " ]\n", 165 | " }\n", 166 | "}\n", 167 | "\n", 168 | "function updateOutputMetadata(scenes, inputMetadata, outputMetadata) {\n", 169 | " outputMetadata.userData = { \"norm_factor\": inputMetadata.normalizationFactor }\n", 170 | "}\n", 171 | "\n", 172 | "function evaluatePixel(sample) {\n", 173 | " let sudoPanW = (sample.B01 + sample.B02 + sample.B03 + sample.B04) / 4\n", 174 | " let ratioW = sample.PAN / sudoPanW\n", 175 | " let red = sample.B02 * ratioW\n", 176 | " let green = sample.B03 * ratioW\n", 177 | " let blue = sample.B04 * ratioW\n", 178 | " let nir = sample.B01 * ratioW \n", 179 | " return {bands: [blue, green, red, nir], bool_mask: [sample.dataMask]};\n", 180 | "}\n", 181 | "'''" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "# Download Sentinel-2 data\n", 191 | "\n", 192 | "get_s2_data = SentinelHubInputTask(\n", 193 | " bands_feature=(FeatureType.DATA, 'BANDS'),\n", 194 | " bands=['B02', 'B03', 'B04', 'B08'],\n", 195 | " resolution=10,\n", 196 | " maxcc=0.5,\n", 197 | " time_difference=timedelta(minutes=120),\n", 198 | " data_collection=DataCollection.SENTINEL2_L1C,\n", 199 | " additional_data=[(FeatureType.MASK, 'dataMask', 'IS_DATA'),\n", 200 | " (FeatureType.MASK, 'CLM'),\n", 201 | " (FeatureType.DATA, 'CLP')],\n", 202 | " 
max_threads=5,\n", 203 | " config=config,\n", 204 | " bands_dtype=np.uint16\n", 205 | ")\n", 206 | "\n", 207 | "# Download pansharpened Deimos.\n", 208 | "get_deimos_data_pansharpened = SentinelHubInputTask(\n", 209 | " bands_feature=(FeatureType.DATA, 'BANDS-DEIMOS'),\n", 210 | " bands=['B01', 'B02', 'B03', 'B04'], # B, G, R, NIR\n", 211 | " resolution=2.5,\n", 212 | " time_difference=timedelta(minutes=120),\n", 213 | " data_collection=DataCollection.define_byoc(''), # INPUT BYOC COLLECTION ID HERE\n", 214 | " additional_data=[(FeatureType.MASK, 'dataMask', 'IS_DATA')],\n", 215 | " max_threads=5,\n", 216 | " evalscript=evalscript_pansharp,\n", 217 | " config=config,\n", 218 | " bands_dtype=np.uint16,\n", 219 | " aux_request_args=dict(processing=dict(upsampling='BICUBIC',\n", 220 | " downsampling='BICUBIC'))\n", 221 | ")\n", 222 | "\n", 223 | "\n", 224 | "save_s2 = SaveTask('', # INPUT WHERE TO SAVE S-2 PATCHES\n", 225 | " config=config,\n", 226 | " overwrite_permission=OverwritePermission.OVERWRITE_PATCH)\n", 227 | "save_dm = SaveTask('', # INPUT WHERE TO SAVE Deimos PATCHES\n", 228 | " config=config,\n", 229 | " overwrite_permission=OverwritePermission.OVERWRITE_FEATURES)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "class AddMetaData(EOTask):\n", 239 | " \"\"\" Adds metadata to Deimos EOPatch. Uses dataframe with metadata that was parsed from tile .dim files. 
\"\"\"\n", 240 | "\n", 241 | " def __init__(self, metadata_ms4, metadata_pan):\n", 242 | " self.metadata_ms4 = metadata_ms4\n", 243 | " self.metadata_pan = metadata_pan\n", 244 | "\n", 245 | " def _create_meta_data_dict(self, eop):\n", 246 | " meta_info_dict = {}\n", 247 | " for ts in eop.timestamp:\n", 248 | " metadata_pan_ts = self.metadata_pan[self.metadata_pan.START_TIME ==\n", 249 | " ts.strftime('%Y-%m-%dT%H:%M:%S')].reset_index()\n", 250 | " metadata_ms4_ts = self.metadata_ms4[self.metadata_ms4.START_TIME ==\n", 251 | " ts.strftime('%Y-%m-%dT%H:%M:%S')].reset_index()\n", 252 | "\n", 253 | " bands_stats_ms4 = metadata_ms4_ts[['STX_STDV_1', 'STX_MEAN_1', 'STX_MIN_1',\n", 254 | " 'STX_MAX_1', 'STX_STDV_2', 'STX_MEAN_2',\n", 255 | " 'STX_MIN_2', 'STX_MAX_2', 'STX_STDV_3',\n", 256 | " 'STX_MEAN_3', 'STX_MIN_3', 'STX_MAX_3',\n", 257 | " 'STX_STDV_4', 'STX_MEAN_4', 'STX_MIN_4', 'STX_MAX_4']].to_dict(orient='index')[0]\n", 258 | "\n", 259 | " bands_physical_info_ms4 = metadata_ms4_ts[['PHYSICAL_GAIN_1', 'PHYSICAL_BIAS_1', 'PHYSICAL_UNIT_1',\n", 260 | " 'ESUN_1', 'PHYSICAL_GAIN_2', 'PHYSICAL_BIAS_2',\n", 261 | " 'PHYSICAL_UNIT_2', 'ESUN_2', 'PHYSICAL_GAIN_3',\n", 262 | " 'PHYSICAL_BIAS_3', 'PHYSICAL_UNIT_3', 'ESUN_3',\n", 263 | " 'PHYSICAL_GAIN_4', 'PHYSICAL_BIAS_4', 'PHYSICAL_UNIT_4',\n", 264 | " 'ESUN_4']].to_dict(orient='index')[0]\n", 265 | "\n", 266 | " bands_stats_pan = metadata_pan_ts[['STX_STDV_4', 'STX_MEAN_4',\n", 267 | " 'STX_MIN_4', 'STX_MAX_4']].to_dict(orient='index')[0]\n", 268 | "\n", 269 | " bands_physical_info_pan = metadata_pan_ts[['PHYSICAL_GAIN_4', 'PHYSICAL_BIAS_4',\n", 270 | " 'PHYSICAL_UNIT_4',\n", 271 | " 'ESUN_4']].to_dict(orient='index')[0]\n", 272 | "\n", 273 | " meta_info_dict[ts] = {'MS4': {'BAND_STATS': bands_stats_ms4, 'PHYSICAL_INFO': bands_physical_info_ms4},\n", 274 | " 'PAN': {'BAND_STATS': bands_stats_pan, 'PHYSICAL_INFO': bands_physical_info_pan}}\n", 275 | " return meta_info_dict\n", 276 | "\n", 277 | " def execute(self, 
eop):\n", 278 | " meta_data_dict = self._create_meta_data_dict(eop)\n", 279 | " eop.meta_info['metadata'] = meta_data_dict\n", 280 | " return eop" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "add_metadata = AddMetaData(metadata_ms4, metadata_pan)" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "# Execute the workflow\n", 299 | "time_interval = ['2020-04-01', '2020-10-31'] # time interval for the SH request\n", 300 | "\n", 301 | "# define additional parameters of the workflow\n", 302 | "execution_args = []\n", 303 | "for idx, bbox in enumerate(bboxes):\n", 304 | " execution_args.append({\n", 305 | " get_deimos_data_pansharpened: {'bbox': bbox, 'time_interval': time_interval},\n", 306 | " get_s2_data: {'bbox': bbox, 'time_interval': time_interval},\n", 307 | " save_s2: {'eopatch_folder': f'eopatch-{idx:04d}'},\n", 308 | " save_dm: {'eopatch_folder': f'eopatch-{idx:04d}'}\n", 309 | " })" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "workflow = EOWorkflow([\n", 319 | " (get_s2_data, [], 'Get S2 data'),\n", 320 | " (get_deimos_data_pansharpened, [], 'Get Deimos pansharpened'),\n", 321 | " (add_metadata, [get_deimos_data_pansharpened], 'Add metadata to DEIMOS'),\n", 322 | " (save_s2, [get_s2_data], 'save S2 data'),\n", 323 | " (save_dm, [add_metadata], 'save deimos data')\n", 324 | "])" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": null, 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [ 333 | "len(execution_args)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "executor = EOExecutor(workflow, execution_args, 
save_logs=True)\n", 343 | "results = executor.run(workers=10, multiprocess=False)\n", 344 | "executor.make_report()" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "eop_d = EOPatch.load('', # SAMPLE DEIMOS PATCH PATH\n", 354 | " filesystem=filesystem)\n", 355 | "eop_s2 = EOPatch.load('', # SAMPLE S-2 PATCH PATH\n", 356 | " filesystem=filesystem)" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "eop_d" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 374 | "fig, ax = plt.subplots(2, figsize=(15, 15))\n", 375 | "ax[0].imshow(eop_d.data['BANDS-DEIMOS'][5][..., [2, 1, 0]] / 10000 * 1.5)\n", 376 | "ax[1].imshow(eop_s2.data['BANDS'][-7][..., [2, 1, 0]] / 10000 * 3.5)" 377 | ] 378 | } 379 | ], 380 | "metadata": { 381 | "kernelspec": { 382 | "display_name": "Python 3", 383 | "language": "python", 384 | "name": "python3" 385 | }, 386 | "language_info": { 387 | "codemirror_mode": { 388 | "name": "ipython", 389 | "version": 3 390 | }, 391 | "file_extension": ".py", 392 | "mimetype": "text/x-python", 393 | "name": "python", 394 | "nbconvert_exporter": "python", 395 | "pygments_lexer": "ipython3", 396 | "version": "3.6.6" 397 | } 398 | }, 399 | "nbformat": 4, 400 | "nbformat_minor": 2 401 | } -------------------------------------------------------------------------------- /notebooks/02a-add-clm-deimos.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import numpy as np\n", 11 | "from eolearn.core import EOPatch, FeatureType, OverwritePermission\n", 12 | "from fs_s3fs import S3FS\n", 13 | 
"from matplotlib import pyplot as plt\n", 14 | "from sentinelhub import SHConfig\n", 15 | "from sg_utils.processing import multiprocess" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "config = SHConfig()\n", 25 | "config.sh_client_id = ''\n", 26 | "config.sh_client_secret = ''\n", 27 | "config.instance_id = ''\n", 28 | "config.aws_access_key_id = ''\n", 29 | "config.aws_secret_access_key = ''" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "filesystem = S3FS(bucket_name='',\n", 39 | " aws_access_key_id=config.aws_access_key_id,\n", 40 | " aws_secret_access_key=config.aws_secret_access_key)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "DIR_DEIMOS = ''" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "def _clms_eops(eop, threshold=95, band=0):\n", 59 | "\n", 60 | " clms = []\n", 61 | " for i, ts in enumerate(eop.timestamp):\n", 62 | " gain = float(eop.meta_info['metadata'][ts]['MS4']['PHYSICAL_INFO'][f'PHYSICAL_GAIN_{band+1}'])\n", 63 | " bias = float(eop.meta_info['metadata'][ts]['MS4']['PHYSICAL_INFO'][f'PHYSICAL_BIAS_{band+1}'])\n", 64 | " clms.append(((eop.data['BANDS-DEIMOS'][i, ..., band]*gain + bias) > threshold))\n", 65 | " return np.array(clms)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "def add_clm(eop_path):\n", 75 | " try:\n", 76 | " deim_eop = EOPatch.load(eop_path, filesystem=filesystem, lazy_loading=True)\n", 77 | " deim_eop.mask['CLM'] = np.expand_dims(_clms_eops(deim_eop), -1)\n", 78 | " deim_eop.save(path=eop_path, filesystem=filesystem, features=[(FeatureType.MASK, 'CLM')],\n", 
79 | " overwrite_permission=OverwritePermission.OVERWRITE_FEATURES)\n", 80 | " return True, eop_path\n", 81 | " except Exception:\n", 82 | " return False, eop_path" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "eops_paths = [os.path.join(DIR_DEIMOS, x) for x in filesystem.listdir(DIR_DEIMOS)]" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "results = multiprocess(add_clm, eops_paths, max_workers=4)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": { 107 | "lines_to_next_cell": 2 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "deim_eop = EOPatch.load(eops_paths[850], filesystem=filesystem, lazy_loading=True)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "fig, ax = plt.subplots(2, figsize=(15, 15))\n", 121 | "tidx = 0\n", 122 | "ax[0].imshow(deim_eop.data['BANDS-DEIMOS'][tidx][..., [2, 1, 0]].squeeze()*(1/10000))\n", 123 | "ax[1].imshow(deim_eop.mask['CLM'][tidx].squeeze())\n", 124 | "ax[1].set_title(f\"CCOV: {deim_eop.mask['CLM'][tidx].mean()}\")" 125 | ] 126 | } 127 | ], 128 | "metadata": { 129 | "kernelspec": { 130 | "display_name": "Environment (conda_tensorflow2_p36)", 131 | "language": "python", 132 | "name": "conda_tensorflow2_p36" 133 | }, 134 | "language_info": { 135 | "codemirror_mode": { 136 | "name": "ipython", 137 | "version": 3 138 | }, 139 | "file_extension": ".py", 140 | "mimetype": "text/x-python", 141 | "name": "python", 142 | "nbconvert_exporter": "python", 143 | "pygments_lexer": "ipython3", 144 | "version": "3.6.5" 145 | } 146 | }, 147 | "nbformat": 4, 148 | "nbformat_minor": 2 149 | } -------------------------------------------------------------------------------- 
/notebooks/02b-add-clm-stats-to-patches.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "from eolearn.core import EOPatch, FeatureType, OverwritePermission\n", 11 | "from fs_s3fs import S3FS\n", 12 | "from sentinelhub import SHConfig\n", 13 | "from sg_utils.processing import multiprocess" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "AWS_ACCESS_KEY_ID = ''\n", 23 | "AWS_SECRET_ACCESS_KEY = ''" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "filesystem = S3FS(bucket_name='',\n", 33 | " aws_access_key_id=AWS_ACCESS_KEY_ID,\n", 34 | " aws_secret_access_key=AWS_SECRET_ACCESS_KEY,\n", 35 | " region='eu-central-1')" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "config = SHConfig()\n", 45 | "config.aws_access_key_id = AWS_ACCESS_KEY_ID\n", 46 | "config.aws_secret_access_key = AWS_SECRET_ACCESS_KEY" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "metadata_ms4 = pd.read_parquet(filesystem.openbin('metadata/deimos_ms4_metadata.pq'))\n", 56 | "metadata_pan = pd.read_parquet(filesystem.openbin('metadata/deimos_pan_metadata.pq'))\n", 57 | "\n", 58 | "metadata_ms4.START_TIME = pd.to_datetime(metadata_ms4.START_TIME)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "DEIMOS_DIR = ''" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | 
"source": [ 76 | "metadata_pan.START_TIME = pd.to_datetime(metadata_pan.START_TIME)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "def add_new_metadata(eop_path):\n", 86 | " eop = EOPatch.load(f'{DEIMOS_DIR}/{eop_path}', filesystem=filesystem, lazy_loading=True)\n", 87 | " meta_info = eop.meta_info\n", 88 | " for timestamp in meta_info['metadata']:\n", 89 | " stats = metadata_ms4[metadata_ms4.START_TIME == timestamp]\n", 90 | " stats_pan = metadata_pan[metadata_pan.START_TIME == timestamp]\n", 91 | " # Add \"cloudy\" median\n", 92 | " meta_info['metadata'][timestamp]['MS4']['CLM_BAND_STATS_'] = {}\n", 93 | " meta_info['metadata'][timestamp]['PAN']['CLM_BAND_STATS'] = {}\n", 94 | "\n", 95 | " for i in range(0, 4):\n", 96 | " meta_info['metadata'][timestamp]['MS4']['BAND_STATS'][f'STX_MEDIAN_{i+1}'] = stats[f'STX_MEDIAN_{i+1}'].iloc[0]\n", 97 | " meta_info['metadata'][timestamp]['MS4']['CLM_BAND_STATS'][\n", 98 | " f'STX_CLM_MEDIAN_{i+1}'] = stats[f'STX_CLM_MEDIAN_{i+1}'].iloc[0]\n", 99 | " meta_info['metadata'][timestamp]['MS4']['CLM_BAND_STATS'][\n", 100 | " f'STX_CLM_MEAN_{i+1}'] = stats[f'STX_CLM_MEAN_{i+1}'].iloc[0]\n", 101 | " meta_info['metadata'][timestamp]['MS4']['CLM_BAND_STATS'][\n", 102 | " f'STX_CLM_STDV_{i+1}'] = stats[f'STX_CLM_STDV_{i+1}'].iloc[0]\n", 103 | "\n", 104 | " meta_info['metadata'][timestamp]['PAN']['BAND_STATS'][f'STX_MEDIAN_1'] = stats_pan[f'STX_MEDIAN_1'].iloc[0]\n", 105 | " meta_info['metadata'][timestamp]['PAN']['CLM_BAND_STATS'][f'STX_CLM_MEDIAN_1'] = stats_pan[f'STX_CLM_MEDIAN_1'].iloc[0]\n", 106 | " meta_info['metadata'][timestamp]['PAN']['CLM_BAND_STATS'][f'STX_CLM_MEAN_1'] = stats_pan[f'STX_CLM_MEAN_1'].iloc[0]\n", 107 | " meta_info['metadata'][timestamp]['PAN']['CLM_BAND_STATS'][f'STX_CLM_STDV_1'] = stats_pan[f'STX_CLM_STDV_1'].iloc[0]\n", 108 | "\n", 109 | " eop.save(path=f'{DEIMOS_DIR}/{eop_path}', filesystem=filesystem,\n", 110 | " 
features=[FeatureType.META_INFO], overwrite_permission=OverwritePermission.OVERWRITE_FEATURES)\n", 111 | "\n", 112 | "\n", 113 | "def add_new_metadata_pansharpened(eop_path):\n", 114 | " eop = EOPatch.load(f'{DEIMOS_DIR}/{eop_path}', filesystem=filesystem, lazy_loading=True)\n", 115 | " meta_info = eop.meta_info\n", 116 | " for timestamp in meta_info['metadata']:\n", 117 | " stats = metadata_ms4[metadata_ms4.START_TIME == timestamp]\n", 118 | " stats_pan = metadata_pan[metadata_pan.START_TIME == timestamp]\n", 119 | " # Add \"cloudy\" median\n", 120 | " meta_info['metadata'][timestamp]['MS4']['CLM_BAND_STATS_PANSHARPENED'] = {}\n", 121 | "\n", 122 | " for i in range(0, 4):\n", 123 | " meta_info['metadata'][timestamp]['MS4']['CLM_BAND_STATS_PANSHARPENED'][\n", 124 | " f'STX_CLM_MEDIAN_PANSHARPENED_{i+1}'] = stats[f'STX_CLM_MEDIAN_PANSHARPENED_{i+1}'].iloc[0]\n", 125 | " meta_info['metadata'][timestamp]['MS4']['CLM_BAND_STATS_PANSHARPENED'][\n", 126 | " f'STX_CLM_MEAN_PANSHARPENED_{i+1}'] = stats[f'STX_CLM_MEAN_PANSHARPENED_{i+1}'].iloc[0]\n", 127 | " meta_info['metadata'][timestamp]['MS4']['CLM_BAND_STATS_PANSHARPENED'][\n", 128 | " f'STX_CLM_STDV_PANSHARPENED_{i+1}'] = stats[f'STX_CLM_STDV_PANSHARPENED_{i+1}'].iloc[0]\n", 129 | "\n", 130 | " eop.save(path=f'{DEIMOS_DIR}/{eop_path}', filesystem=filesystem,\n", 131 | " features=[FeatureType.META_INFO], overwrite_permission=OverwritePermission.OVERWRITE_FEATURES)\n", 132 | "\n", 133 | "\n", 134 | "def add_new_metadata_radiance(eop_path):\n", 135 | " eop = EOPatch.load(f'{DEIMOS_DIR}/{eop_path}', filesystem=filesystem, lazy_loading=True)\n", 136 | " meta_info = eop.meta_info\n", 137 | " for timestamp in meta_info['metadata']:\n", 138 | " stats = metadata_ms4[metadata_ms4.START_TIME == timestamp]\n", 139 | " stats_pan = metadata_pan[metadata_pan.START_TIME == timestamp]\n", 140 | " # Add \"cloudy\" median\n", 141 | " meta_info['metadata'][timestamp]['MS4']['CLM_RADIANCE_BAND_STATS'] = {}\n", 142 | " 
meta_info['metadata'][timestamp]['PAN']['CLM_RADIANCE_BAND_STATS'] = {}\n", 143 | "\n", 144 | " for i in range(0, 4):\n", 145 | " meta_info['metadata'][timestamp]['MS4']['CLM_RADIANCE_BAND_STATS'][\n", 146 | " f'STX_CLM_RADIANCE_MEDIAN_{i+1}'] = stats[f'STX_CLM_RADIANCE_MEDIAN_{i+1}'].iloc[0]\n", 147 | " meta_info['metadata'][timestamp]['MS4']['CLM_RADIANCE_BAND_STATS'][\n", 148 | " f'STX_CLM_RADIANCE_MEAN_{i+1}'] = stats[f'STX_CLM_RADIANCE_MEAN_{i+1}'].iloc[0]\n", 149 | " meta_info['metadata'][timestamp]['MS4']['CLM_RADIANCE_BAND_STATS'][\n", 150 | " f'STX_CLM_RADIANCE_STDV_{i+1}'] = stats[f'STX_CLM_RADIANCE_STDV_{i+1}'].iloc[0]\n", 151 | "\n", 152 | " meta_info['metadata'][timestamp]['PAN']['CLM_RADIANCE_BAND_STATS'][f'STX_CLM_RADIANCE_MEDIAN_1'] = stats_pan[f'STX_CLM_RADIANCE_MEDIAN_1'].iloc[0]\n", 153 | " meta_info['metadata'][timestamp]['PAN']['CLM_RADIANCE_BAND_STATS'][f'STX_CLM_RADIANCE_MEAN_1'] = stats_pan[f'STX_CLM_RADIANCE_MEAN_1'].iloc[0]\n", 154 | " meta_info['metadata'][timestamp]['PAN']['CLM_RADIANCE_BAND_STATS'][f'STX_CLM_RADIANCE_STDV_1'] = stats_pan[f'STX_CLM_RADIANCE_STDV_1'].iloc[0]\n", 155 | "\n", 156 | " eop.save(path=f'{DEIMOS_DIR}/{eop_path}', filesystem=filesystem,\n", 157 | " features=[FeatureType.META_INFO], overwrite_permission=OverwritePermission.OVERWRITE_FEATURES)\n", 158 | "\n", 159 | "\n", 160 | "def add_new_metadata_radiance_pansharpened(eop_path):\n", 161 | " eop = EOPatch.load(f'{DEIMOS_DIR}/{eop_path}', filesystem=filesystem, lazy_loading=True)\n", 162 | " meta_info = eop.meta_info\n", 163 | " for timestamp in meta_info['metadata']:\n", 164 | " stats = metadata_ms4[metadata_ms4.START_TIME == timestamp]\n", 165 | " stats_pan = metadata_pan[metadata_pan.START_TIME == timestamp]\n", 166 | " # Add \"cloudy\" median\n", 167 | " meta_info['metadata'][timestamp]['MS4']['CLM_RADIANCE_BAND_STATS_PANSHARPENED'] = {}\n", 168 | "\n", 169 | " for i in range(0, 4):\n", 170 | " 
meta_info['metadata'][timestamp]['MS4']['CLM_RADIANCE_BAND_STATS_PANSHARPENED'][\n", 171 | " f'STX_CLM_RADIANCE_MEDIAN_PANSHARPENED_{i+1}'] = stats[f'STX_CLM_RADIANCE_MEDIAN_PANSHARPENED_{i+1}'].iloc[0]\n", 172 | " meta_info['metadata'][timestamp]['MS4']['CLM_RADIANCE_BAND_STATS_PANSHARPENED'][\n", 173 | " f'STX_CLM_RADIANCE_MEAN_PANSHARPENED_{i+1}'] = stats[f'STX_CLM_RADIANCE_MEAN_PANSHARPENED_{i+1}'].iloc[0]\n", 174 | " meta_info['metadata'][timestamp]['MS4']['CLM_RADIANCE_BAND_STATS_PANSHARPENED'][\n", 175 | " f'STX_CLM_RADIANCE_STDV_PANSHARPENED_{i+1}'] = stats[f'STX_CLM_RADIANCE_STDV_PANSHARPENED_{i+1}'].iloc[0]\n", 176 | "\n", 177 | " eop.save(path=f'{DEIMOS_DIR}/{eop_path}', filesystem=filesystem,\n", 178 | " features=[FeatureType.META_INFO], overwrite_permission=OverwritePermission.OVERWRITE_FEATURES)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": { 185 | "lines_to_next_cell": 2 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "multiprocess(add_new_metadata_radiance, filesystem.listdir(DEIMOS_DIR), max_workers=15)" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": { 196 | "lines_to_next_cell": 2 197 | }, 198 | "outputs": [], 199 | "source": [ 200 | "multiprocess(add_new_metadata_radiance_pansharpened, filesystem.listdir(DEIMOS_DIR), max_workers=15)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": { 207 | "lines_to_next_cell": 2 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "eop = EOPatch.load(f'{DEIMOS_DIR}/{ filesystem.listdir(DEIMOS_DIR)[0]}', filesystem=filesystem, lazy_loading=True)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "multiprocess(add_new_metadata, filesystem.listdir(DEIMOS_DIR), max_workers=15)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | 
"execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "multiprocess(add_new_metadata_pansharpened, filesystem.listdir(DEIMOS_DIR), max_workers=15)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "DEIMOS_DIR = ''" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "multiprocess(add_new_metadata_radiance_pansharpened, filesystem.listdir(DEIMOS_DIR), max_workers=15)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "multiprocess(add_new_metadata_pansharpened, filesystem.listdir(DEIMOS_DIR), max_workers=15)" 257 | ] 258 | } 259 | ], 260 | "metadata": { 261 | "kernelspec": { 262 | "display_name": "Environment (conda_tensorflow2_p36)", 263 | "language": "python", 264 | "name": "conda_tensorflow2_p36" 265 | }, 266 | "language_info": { 267 | "codemirror_mode": { 268 | "name": "ipython", 269 | "version": 3 270 | }, 271 | "file_extension": ".py", 272 | "mimetype": "text/x-python", 273 | "name": "python", 274 | "nbconvert_exporter": "python", 275 | "pygments_lexer": "ipython3", 276 | "version": "3.6.5" 277 | } 278 | }, 279 | "nbformat": 4, 280 | "nbformat_minor": 2 281 | } -------------------------------------------------------------------------------- /notebooks/03-sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "lines_to_end_of_cell_marker": 2 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "\n", 13 | "import numpy as np\n", 14 | "from eolearn.core import EOPatch, EOTask\n", 15 | "from fs_s3fs import S3FS\n", 16 | "from matplotlib import pyplot as plt\n", 17 | "from 
sentinelhub import BBox, SHConfig\n", 18 | "\n", 19 | "from sg_utils.processing import multiprocess" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "# Config" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "config = SHConfig()\n", 36 | "config.instance_id = ''\n", 37 | "config.aws_access_key_id = ''\n", 38 | "config.aws_secret_access_key = ''" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "filesystem = S3FS(bucket_name='',\n", 48 | " aws_access_key_id=config.aws_access_key_id,\n", 49 | " aws_secret_access_key=config.aws_secret_access_key)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "# Execute sampling" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "class SamplePatchlets(EOTask):\n", 66 | "\n", 67 | " MS4_DEIMOS_SCALING = 4\n", 68 | "\n", 69 | " def __init__(self, s2_patchlet_size: int, num_samples: int):\n", 70 | " self.s2_patchlet_size = s2_patchlet_size\n", 71 | " self.num_samples = num_samples\n", 72 | "\n", 73 | " def _calculate_sampled_bbox(self, bbox: BBox, r: int, c: int, s: int, resolution: float) -> BBox:\n", 74 | " return BBox(((bbox.min_x + resolution * c, bbox.max_y - resolution * (r + s)),\n", 75 | " (bbox.min_x + resolution * (c + s), bbox.max_y - resolution * r)),\n", 76 | " bbox.crs)\n", 77 | "\n", 78 | " def _sample_s2(self, eop: EOPatch, row: int, col: int, size: int, resolution: float = 10):\n", 79 | " sampled_eop = EOPatch(timestamp=eop.timestamp, scalar=eop.scalar, meta_info=eop.meta_info)\n", 80 | " sampled_eop.data['CLP'] = eop.data['CLP'][:, row:row + size, col:col + size, :]\n", 81 | " sampled_eop.mask['CLM'] = eop.mask['CLM'][:, row:row + size, col:col + size, 
:]\n", 82 | " sampled_eop.mask['IS_DATA'] = eop.mask['IS_DATA'][:, row:row + size, col:col + size, :]\n", 83 | " sampled_eop.data['BANDS'] = eop.data['BANDS'][:, row:row + size, col:col + size, :]\n", 84 | " sampled_eop.scalar_timeless['PATCHLET_LOC'] = np.array([row, col, size])\n", 85 | " sampled_eop.bbox = self._calculate_sampled_bbox(eop.bbox, r=row, c=col, s=size, resolution=resolution)\n", 86 | " sampled_eop.meta_info['size_x'] = size\n", 87 | " sampled_eop.meta_info['size_y'] = size\n", 88 | " return sampled_eop\n", 89 | "\n", 90 | " def _sample_deimos(self, eop: EOPatch, row: int, col: int, size: int, resolution: float = 2.5):\n", 91 | " sampled_eop = EOPatch(timestamp=eop.timestamp, scalar=eop.scalar, meta_info=eop.meta_info)\n", 92 | " sampled_eop.data['BANDS-DEIMOS'] = eop.data['BANDS-DEIMOS'][:, row:row + size, col:col + size, :]\n", 93 | " sampled_eop.mask['CLM'] = eop.mask['CLM'][:, row:row + size, col:col + size, :]\n", 94 | " sampled_eop.mask['IS_DATA'] = eop.mask['IS_DATA'][:, row:row + size, col:col + size, :]\n", 95 | "\n", 96 | " sampled_eop.scalar_timeless['PATCHLET_LOC'] = np.array([row, col, size])\n", 97 | "\n", 98 | " sampled_eop.bbox = self._calculate_sampled_bbox(eop.bbox, r=row, c=col, s=size, resolution=resolution)\n", 99 | " sampled_eop.meta_info['size_x'] = size\n", 100 | " sampled_eop.meta_info['size_y'] = size\n", 101 | " return sampled_eop\n", 102 | "\n", 103 | " def execute(self, eopatch_s2, eopatch_deimos, buffer=20, seed=42):\n", 104 | " _, n_rows, n_cols, _ = eopatch_s2.data['BANDS'].shape\n", 105 | " np.random.seed(seed)\n", 106 | " eops_out = []\n", 107 | "\n", 108 | " for patchlet_num in range(0, self.num_samples):\n", 109 | " row = np.random.randint(buffer, n_rows - self.s2_patchlet_size - buffer)\n", 110 | " col = np.random.randint(buffer, n_cols - self.s2_patchlet_size - buffer)\n", 111 | " sampled_s2 = self._sample_s2(eopatch_s2, row, col, self.s2_patchlet_size)\n", 112 | " sampled_deimos = 
self._sample_deimos(eopatch_deimos,\n", 113 | " row*self.MS4_DEIMOS_SCALING,\n", 114 | " col*self.MS4_DEIMOS_SCALING,\n", 115 | " self.s2_patchlet_size*self.MS4_DEIMOS_SCALING)\n", 116 | " eops_out.append((sampled_s2, sampled_deimos))\n", 117 | " return eops_out" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "def sample_patch(eop_path_s2: str, eop_path_deimos,\n", 127 | " sampled_folder_s2, sampled_folder_deimos,\n", 128 | " s2_patchlet_size, num_samples, filesystem, buffer=20) -> None:\n", 129 | "\n", 130 | " task = SamplePatchlets(s2_patchlet_size=s2_patchlet_size, num_samples=num_samples)\n", 131 | " eop_name = os.path.basename(eop_path_s2)\n", 132 | " try:\n", 133 | " eop_s2 = EOPatch.load(eop_path_s2, filesystem=filesystem, lazy_loading=True)\n", 134 | " eop_deimos = EOPatch.load(eop_path_deimos, filesystem=filesystem, lazy_loading=True)\n", 135 | " patchlets = task.execute(eop_s2, eop_deimos, buffer=buffer)\n", 136 | " for i, (patchlet_s2, patchlet_deimos) in enumerate(patchlets):\n", 137 | "\n", 138 | " patchlet_s2.save(os.path.join(sampled_folder_s2, f'{eop_name}_{i}'),\n", 139 | " filesystem=filesystem)\n", 140 | "\n", 141 | " patchlet_deimos.save(os.path.join(sampled_folder_deimos, f'{eop_name}_{i}'),\n", 142 | " filesystem=filesystem)\n", 143 | "\n", 144 | " except KeyError as e:\n", 145 | " print(f'Key error. Could not find key: {e}')\n", 146 | " except ValueError as e:\n", 147 | " print(f'Value error. 
Value does not exist: {e}')" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "EOPS_S2 = ''\n", 157 | "EOPS_DEIMOS = ''\n", 158 | "\n", 159 | "SAMPLED_S2_PATH = ''\n", 160 | "SAMPLED_DEIMOS_3M_PATH = ''\n", 161 | "\n", 162 | "\n", 163 | "eop_names = filesystem.listdir(EOPS_DEIMOS) # Both folders have the same EOPatches\n", 164 | "\n", 165 | "\n", 166 | "def sample_single(eop_name):\n", 167 | " path_s2 = os.path.join(EOPS_S2, eop_name)\n", 168 | " path_deimos = os.path.join(EOPS_DEIMOS, eop_name)\n", 169 | "\n", 170 | " sample_patch(path_s2, path_deimos, SAMPLED_S2_PATH, SAMPLED_DEIMOS_3M_PATH,\n", 171 | " s2_patchlet_size=32, num_samples=140, filesystem=filesystem, buffer=20)\n", 172 | "\n", 173 | "\n", 174 | "multiprocess(sample_single, eop_names, max_workers=16)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": {}, 180 | "source": [ 181 | "# Look at an example" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "sampled_s2 = EOPatch.load(os.path.join(SAMPLED_S2_PATH, 'eopatch-0000_122'), filesystem=filesystem)\n", 191 | "sampled_deimos = EOPatch.load(os.path.join(SAMPLED_DEIMOS_3M_PATH, 'eopatch-0000_122'), filesystem=filesystem)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "def _get_closest_timestamp_idx(eop, ref_timestamp):\n", 201 | " closest_idx = 0\n", 202 | " for i, ts in enumerate(eop.timestamp):\n", 203 | " if abs((ts - ref_timestamp).days) < abs((eop.timestamp[closest_idx] - ref_timestamp).days):\n", 204 | " closest_idx = i\n", 205 | " return closest_idx" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "fig, ax = 
plt.subplots(ncols=2, figsize=(15, 15))\n", 215 | "idx_deimos = 1\n", 216 | "closest_idx = _get_closest_timestamp_idx(sampled_s2, sampled_deimos.timestamp[idx_deimos])\n", 217 | "\n", 218 | "ax[0].imshow(sampled_s2.data['BANDS'][closest_idx][..., [2, 1, 0]] / 10000*3.5)\n", 219 | "ax[1].imshow(sampled_deimos.data['BANDS-DEIMOS'][idx_deimos][..., [2, 1, 0]] / 12000)" 220 | ] 221 | } 222 | ], 223 | "metadata": { 224 | "kernelspec": { 225 | "display_name": "Environment (conda_tensorflow2_p36)", 226 | "language": "python", 227 | "name": "conda_tensorflow2_p36" 228 | }, 229 | "language_info": { 230 | "codemirror_mode": { 231 | "name": "ipython", 232 | "version": 3 233 | }, 234 | "file_extension": ".py", 235 | "mimetype": "text/x-python", 236 | "name": "python", 237 | "nbconvert_exporter": "python", 238 | "pygments_lexer": "ipython3", 239 | "version": "3.6.5" 240 | } 241 | }, 242 | "nbformat": 4, 243 | "nbformat_minor": 2 244 | } -------------------------------------------------------------------------------- /notebooks/04-sampled-to-npz.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Warning! 
This notebook requires at least 90GB of RAM" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import numpy as np\n", 26 | "import pandas as pd\n", 27 | "from dateutil.parser import parse\n", 28 | "from eolearn.core import EOPatch\n", 29 | "from fs_s3fs import S3FS\n", 30 | "from sentinelhub import SHConfig" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": { 37 | "lines_to_next_cell": 2 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "from sg_utils.processing import multiprocess" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "config = SHConfig()\n", 51 | "config.aws_access_key_id = ''\n", 52 | "config.aws_secret_access_key = ''" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "filesystem = S3FS(bucket_name='',\n", 62 | " aws_access_key_id=config.aws_access_key_id,\n", 63 | " aws_secret_access_key=config.aws_secret_access_key)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": { 70 | "lines_to_next_cell": 2 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "DIR_SAMPLED_S2 = ''\n", 75 | "DIR_SAMPLED_DEIMOS_1M = ''" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "metadata_ms4 = pd.read_parquet(filesystem.openbin('metadata/deimos_ms4_metadata.pq'))\n", 85 | "metadata_ms4['Country'] = metadata_ms4.Projection_OGCWKT.apply(lambda x: 'Lithuania' if '34N' in x else 'Cyprus') # ! Warning, doesn't ! 
# Maximum tolerated cloud fraction per frame and the look-back window (days)
# used when pairing Sentinel-2 frames with a Deimos acquisition.
MAX_CC = .05
N_DAYS = 60

# Divisor applied to the raw Sentinel-2 band values before they are stored.
S2_FACTOR = 10000.0


def normalize_deimos(eop, pan=False):
    """Standardise the four Deimos bands of ``eop`` in place and return it.

    For every timestamp and channel the stored values are transformed as
    ``((value * gain + bias) - median) / std`` where gain/bias come from
    ``PHYSICAL_INFO`` and median/std from
    ``CLM_RADIANCE_BAND_STATS_PANSHARPENED`` of the per-timestamp ``MS4``
    metadata.

    NOTE(review): the statistics are always read from the ``*_PANSHARPENED``
    metadata keys, also when ``pan`` is False -- confirm this is intended.

    :param eop: patch exposing ``data``, ``timestamp`` and ``meta_info``
    :param pan: normalise ``'PANSHARPENED-DEIMOS'`` instead of ``'BANDS-DEIMOS'``
    :return: the same ``eop`` with the feature replaced by a float32 array
    """
    bname = 'PANSHARPENED-DEIMOS' if pan else 'BANDS-DEIMOS'

    # Cast once up front. The original re-cast (and therefore re-copied) the
    # whole array on every channel of every timestamp. ``astype`` returns a
    # copy, so the array originally stored on the patch is left untouched.
    bands = eop.data[bname].astype(np.float32)
    for i, ts in enumerate(eop.timestamp):
        ms4 = eop.meta_info['metadata'][ts]['MS4']
        stats = ms4['CLM_RADIANCE_BAND_STATS_PANSHARPENED']
        physical = ms4['PHYSICAL_INFO']
        for chnl in range(4):
            median = float(stats[f'STX_CLM_RADIANCE_MEDIAN_PANSHARPENED_{chnl+1}'])
            std = float(stats[f'STX_CLM_RADIANCE_STDV_PANSHARPENED_{chnl+1}'])
            gain = float(physical[f'PHYSICAL_GAIN_{chnl+1}'])
            bias = float(physical[f'PHYSICAL_BIAS_{chnl+1}'])

            bands[i, ..., chnl] = ((bands[i, ..., chnl]*gain + bias) - median) / std

    eop.data[bname] = bands
    return eop


def _valid_idxs_deimos(eop, max_cc, clm_band=0, threshold=95):
    """Return the indices of Deimos frames that are usable as labels.

    A frame is kept when the mean of its ``CLM`` mask is at most ``max_cc``
    and the ``IS_DATA`` mask of the patch is entirely set.

    ``clm_band`` and ``threshold`` are unused; they are kept only so existing
    callers keep working (they belonged to a threshold-based cloud estimate
    that is no longer computed).

    NOTE(review): the ``IS_DATA`` test looks at the whole mask, not at frame
    ``i`` -- a single invalid pixel in any frame rejects every frame. Confirm
    this is intended.
    """
    # Loop-invariant: evaluated once instead of once per frame.
    fully_valid = eop.mask['IS_DATA'].mean() == 1
    return [i for i, _ in enumerate(eop.timestamp)
            if fully_valid and eop.mask['CLM'][i].mean() <= max_cc]


def _filter_cloudy_s2(eop, max_cc):
    """Drop cloudy Sentinel-2 frames from ``eop`` in place and return it.

    Frames are kept under the same rule as in ``_valid_idxs_deimos``. The
    ``BANDS``/``CLP`` data, ``CLM``/``IS_DATA`` masks, ``NORM_FACTORS`` scalar
    and the timestamp list are all subset to the kept frames.
    """
    # Both criteria are evaluated on the unfiltered masks.
    fully_valid = eop.mask['IS_DATA'].mean() == 1
    keep = [i for i, _ in enumerate(eop.timestamp)
            if fully_valid and eop.mask['CLM'][i, ...].mean() <= max_cc]

    for name in ('BANDS', 'CLP'):
        eop.data[name] = eop.data[name][keep, ...]
    for name in ('CLM', 'IS_DATA'):
        eop.mask[name] = eop.mask[name][keep, ...]
    eop.scalar['NORM_FACTORS'] = eop.scalar['NORM_FACTORS'][keep, ...]

    eop.timestamp = list(np.array(eop.timestamp)[keep])
    return eop


def _get_closest_timestamp_idx(eop, ref_timestamp):
    """Return the index of the timestamp of ``eop`` closest to ``ref_timestamp``.

    Distances are compared as absolute timedeltas. The previous version
    compared ``abs(delta.days)``; ``timedelta.days`` floors negative values,
    so a frame one hour *before* the reference counted as one day away while
    a frame eleven hours *after* counted as zero days away.

    Ties go to the earliest index; an empty timestamp list yields ``0`` as
    before.
    """
    return min(range(len(eop.timestamp)),
               key=lambda i: abs(eop.timestamp[i] - ref_timestamp),
               default=0)


def _idxs_within_n_days(eop, ref_ts, n_days=60):
    """Return indices of frames acquired before ``ref_ts`` within ``n_days``.

    A frame qualifies when ``0 < (ref_ts - ts).days < n_days``, i.e. it is at
    least one full day older than the reference and fewer than ``n_days`` days
    older. Frames after the reference never qualify.

    The bound used to be hard-coded to 60, silently ignoring ``n_days``.
    """
    return [i for i, ts in enumerate(eop.timestamp)
            if 0 < (ref_ts - ts).days < n_days]


DIR_SAMPLED_S2 = ''
DIR_SAMPLED_DEIMOS = ''


def construct_features_labels(eop_name):
    """Build the (Sentinel-2 stack, Deimos frame) samples of one patchlet.

    Loads the Sentinel-2 and Deimos patches called ``eop_name`` from the
    module-level ``filesystem``, removes cloudy frames, normalises Deimos and
    pairs every remaining Deimos frame with the Sentinel-2 frames returned by
    ``_idxs_within_n_days`` for ``N_DAYS``.

    :return: dict of parallel sequences under the keys ``features``,
        ``labels``, ``patchlet_name``, ``timestamps_deimos``,
        ``timestamps_s2`` and ``countries``. On any failure the error is
        printed and all sequences are empty, so one bad patch does not abort
        a multiprocessing run.
    """
    features, labels, s2_timestamps = [], [], []
    try:

        s2 = EOPatch.load(os.path.join(DIR_SAMPLED_S2, eop_name), filesystem=filesystem, lazy_loading=True)
        deimos = EOPatch.load(os.path.join(DIR_SAMPLED_DEIMOS, eop_name), filesystem=filesystem, lazy_loading=True)
        s2 = _filter_cloudy_s2(s2, MAX_CC)
        non_cloudy_idxs = _valid_idxs_deimos(deimos, MAX_CC)
        timestamps = np.array(deimos.timestamp)[non_cloudy_idxs]

        deimos_data = normalize_deimos(deimos, pan=False).data['BANDS-DEIMOS'][non_cloudy_idxs, ...]
        for ts, deim in zip(timestamps, deimos_data):

            s2_idxs = _idxs_within_n_days(s2, ts, N_DAYS)

            s2_timestamps.append(np.array(s2.timestamp)[s2_idxs])
            features.append(s2.data['BANDS'][s2_idxs, ...] / S2_FACTOR)
            labels.append(deim)

        return {'features': features, 'labels': labels,
                'patchlet_name': [eop_name]*len(features),
                'timestamps_deimos': timestamps,
                'timestamps_s2': s2_timestamps,
                'countries': [timestamp_country_map[ts] for ts in timestamps]
                }
    except Exception as e:
        # Deliberate best-effort: report and return an empty result.
        print(f"Failed for {eop_name} with error: {e}")
        return {'features': [], 'labels': [],
                'patchlet_name': [],
                'timestamps_deimos': [],
                'timestamps_s2': [],
                'countries': []
                }
def save_npz(result):
    """Write every sample of one patchlet to its own ``.npz`` file.

    ``result`` is a dict of parallel sequences under the keys ``features``,
    ``labels``, ``patchlet_name``, ``timestamps_deimos``, ``timestamps_s2``
    and ``countries``. Samples whose feature stack is empty are skipped.
    Files are written through the module-level ``filesystem`` as
    ``/data_<patchlet>_<i>.npz`` where ``i`` is the sample position.

    :return: ``pandas.DataFrame`` with one row per file written
    """
    info = []
    samples = zip(result['features'], result['labels'],
                  result['patchlet_name'],
                  result['timestamps_deimos'],
                  result['timestamps_s2'],
                  result['countries'])
    for i, (feats, labels, patch_name, ts_deim, ts_s2, ts_country) in enumerate(samples):

        if len(feats) == 0:
            continue

        filename = f'data_{patch_name}_{i}.npz'
        info.append(dict(patchlet=patch_name, eopatch=patch_name.split('_')[0],
                         countries=ts_country, timestamp_deimos=ts_deim,
                         timestamps_s2=ts_s2,
                         singleton_npz_filename=filename))
        with filesystem.openbin(f'/{filename}', 'wb') as f:
            # NOTE: the 'timetamps_deimos' key is misspelled, but it is the
            # name every reader of these files looks up -- do not rename it
            # without migrating the readers and the files already written.
            np.savez(f, features=feats,
                     labels=labels,
                     patchlet=patch_name,
                     timetamps_deimos=ts_deim,
                     timestamps_s2=ts_s2,
                     countries=ts_country)
    return pd.DataFrame(info)


def create_info(filename):
    """Rebuild the metadata record of one sample from its ``.npz`` file.

    Opens ``/<filename>`` through the module-level ``filesystem`` and returns
    the same fields ``save_npz`` records for each sample. Both the file
    handle and the ``NpzFile`` are closed before returning (the previous
    version left them open); the arrays are therefore read inside the
    ``with`` block because ``NpzFile`` loads members lazily.

    ``allow_pickle=True`` is required for the object-typed members; only use
    this on files produced by this pipeline.
    """
    with filesystem.openbin(f'/{filename}') as fh, np.load(fh, allow_pickle=True) as npz:
        patchlet_name = npz['patchlet']
        timestamp_deimos = npz['timetamps_deimos']
        timestamps_s2 = npz['timestamps_s2']
        countries = npz['countries']
    eopatch_name = str(patchlet_name).split('_')[0]
    return dict(patchlet=patchlet_name,
                eopatch=eopatch_name,
                countries=countries,
                timestamp_deimos=timestamp_deimos,
                timestamps_s2=timestamps_s2,
                singleton_npz_filename=filename)
"execution_count": null, 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "df.timestamp_deimos = df.timestamp_deimos.apply(lambda x: parse(str(x)))" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": null, 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [ 409 | "df.countries = df.countries.astype(str)\n", 410 | "df.patchlet = df.patchlet.astype(str)" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": null, 416 | "metadata": {}, 417 | "outputs": [], 418 | "source": [ 419 | "df" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [ 428 | "with filesystem.openbin('metadata/npz_info_small.pq', 'wb') as f:\n", 429 | " df[['patchlet', 'eopatch', 'countries', 'timestamp_deimos',\n", 430 | " 'singleton_npz_filename', 'timestamps_s2_str', 'num_tstamps']].to_parquet(f)" 431 | ] 432 | } 433 | ], 434 | "metadata": { 435 | "kernelspec": { 436 | "display_name": "Python 3.6.9 64-bit ('venv': virtualenv)", 437 | "language": "python", 438 | "name": "python369jvsc74a57bd008539c228c0b1d46fd3ab380299090bd67be578e8cdd5c516ba9f15efc81c90d" 439 | } 440 | }, 441 | "nbformat": 4, 442 | "nbformat_minor": 2 443 | } -------------------------------------------------------------------------------- /notebooks/05a-train-test-split.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "lines_to_next_cell": 2 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import pandas as pd\n", 13 | "from fs_s3fs import S3FS" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "aws_access_key_id = ''\n", 23 | "aws_secret_access_key = ''\n", 24 | "\n", 25 | "filesystem = S3FS(bucket_name='',\n", 
def set_train_test_val(eopatch, train_set, val_set, test_set):
    """Return the name of the split that contains ``eopatch``.

    The splits are probed in the order train, validation, test and the first
    match wins.

    :return: ``'train'``, ``'validation'`` or ``'test'``
    :raises ValueError: if ``eopatch`` is in none of the three collections
    """
    splits = (('train', train_set), ('validation', val_set), ('test', test_set))
    for label, members in splits:
        if eopatch in members:
            return label
    raise ValueError(f"Could not find eopatch: {eopatch} in train/test/validation sets.")
| "source": [ 94 | "data_df['train_test_validation'] = data_df.eopatch.apply(\n", 95 | " lambda x: set_train_test_val(x, eops_train, eops_val, eops_test))" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "with filesystem.openbin('metadata/npz_info_small.pq', 'wb') as f:\n", 105 | " data_df.to_parquet(f)" 106 | ] 107 | } 108 | ], 109 | "metadata": { 110 | "kernelspec": { 111 | "display_name": "Environment (conda_pytorch_p36)", 112 | "language": "python", 113 | "name": "conda_pytorch_p36" 114 | } 115 | }, 116 | "nbformat": 4, 117 | "nbformat_minor": 2 118 | } -------------------------------------------------------------------------------- /notebooks/05b-find-cloudy-neighbours.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "lines_to_next_cell": 2 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import geopandas as gpd\n", 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "from eolearn.core import EOPatch\n", 24 | "from fs_s3fs import S3FS\n", 25 | "from sentinelhub import CRS, SHConfig\n", 26 | "from tqdm.auto import tqdm" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "config = SHConfig()\n", 36 | "config.aws_access_key_id = ''\n", 37 | "config.aws_secret_access_key = ''" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "filesystem = S3FS(bucket_name='',\n", 47 | " aws_access_key_id=config.aws_access_key_id,\n", 48 | " aws_secret_access_key=config.aws_secret_access_key)" 49 | ] 50 | 
def cloudy_idxs_deimos(eop, max_cc, threshold=100):
    """Return the indices of frames of ``eop`` considered cloudy.

    A frame is cloudy when the mean of its ``CLM`` mask, taken over the
    pixels flagged in its ``IS_DATA`` mask, is greater than ``max_cc``.

    Changes with respect to the previous version:

    * the two metadata look-ups whose results were discarded are gone, so the
      function no longer needs ``meta_info`` at all;
    * ``IS_DATA`` is cast to bool before it is used as an index, so an
      integer 0/1 mask selects pixels instead of being interpreted as fancy
      integer indices;
    * frames without any valid pixel are skipped explicitly. They were never
      reported before either (mean of an empty selection is NaN and
      ``NaN > max_cc`` is False), but they triggered a NumPy runtime warning.

    ``threshold`` is unused and kept only for backward compatibility.
    """
    idxs = []
    for i, _ in enumerate(eop.timestamp):
        is_data_mask = eop.mask['IS_DATA'][i].squeeze().astype(bool)
        if not is_data_mask.any():
            continue
        cloud_coverage = eop.mask['CLM'][i][is_data_mask].mean()
        if cloud_coverage > max_cc:
            idxs.append(i)

    return idxs
def get_neighbouring_eops(gdf):
    """Fill the ``neighbouring_eops`` column of ``gdf`` and return ``gdf``.

    For every row the column receives the ``eop_name`` of all other rows
    whose geometry is not disjoint from the row's geometry, joined with
    ``", "``. The frame is modified in place.
    """
    for index, row in gdf.iterrows():
        # rows whose geometry touches or overlaps this one (this row included)
        not_disjoint = ~gdf.geometry.disjoint(row.geometry)
        candidates = gdf[not_disjoint].eop_name.tolist()

        # a patch is not its own neighbour
        others = [name for name in candidates if row.eop_name != name]

        gdf.at[index, "neighbouring_eops"] = ", ".join(others)
    return gdf


def get_cloudy_eop_timestamps(gdf):
    """Return the set of ``(timestamp, eop_name)`` pairs next to a cloudy frame.

    Every value in a row's ``cloudy_timestamps`` is paired with every
    comma-separated, whitespace-stripped name in that row's
    ``neighbouring_eops``.
    """
    return {
        (cloudy_timestamp, neighbour.strip())
        for _, row in gdf.iterrows()
        for cloudy_timestamp in row.cloudy_timestamps
        for neighbour in row.neighbouring_eops.split(',')
    }


def is_shadow_v2(eopatch, timestamp_deimos, country):
    """Tell whether ``(timestamp_deimos, eopatch)`` is in the country's cloudy set.

    Looks the pair up in the module-level ``cloudy_lithuania`` or
    ``cloudy_cyprus`` set, after converting the timestamp with
    ``to_pydatetime()``.

    :raises ValueError: if ``country`` is neither ``'Lithuania'`` nor ``'Cyprus'``
    """
    if country not in ('Lithuania', 'Cyprus'):
        raise ValueError("Wrong country")

    key = (timestamp_deimos.to_pydatetime(), eopatch)
    if country == 'Lithuania':
        return key in cloudy_lithuania
    return key in cloudy_cyprus
268 | " data_df.to_parquet(f)" 269 | ] 270 | } 271 | ], 272 | "metadata": { 273 | "kernelspec": { 274 | "display_name": "Environment (conda_tensorflow2_p36)", 275 | "language": "python", 276 | "name": "conda_tensorflow2_p36" 277 | } 278 | }, 279 | "nbformat": 4, 280 | "nbformat_minor": 2 281 | } -------------------------------------------------------------------------------- /notebooks/05c-calculate-s2-normalizations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "from eolearn.core import EOPatch\n", 11 | "from sentinelhub import SHConfig\n", 12 | "from fs_s3fs import S3FS\n", 13 | "import pandas as pd\n", 14 | "from matplotlib import pyplot as plt \n", 15 | "from collections import defaultdict\n", 16 | "from datetime import datetime\n", 17 | "import seaborn as sns\n", 18 | "from tqdm.auto import tqdm\n", 19 | "import os " 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "config = SHConfig()\n", 29 | "config.aws_access_key_id = ''\n", 30 | "config.aws_secret_access_key = ''" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "# Per timestamp" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "\n", 47 | "filesystem = S3FS(bucket_name='', \n", 48 | " aws_access_key_id=config.aws_access_key_id, \n", 49 | " aws_secret_access_key=config.aws_secret_access_key)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "NPZ_LOC = ''" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | 
"source": [ 67 | "metadata_ms4 = pd.read_parquet(filesystem.openbin('metadata/deimos_ms4_metadata.pq'))\n", 68 | "metadata_pan = pd.read_parquet(filesystem.openbin('metadata/deimos_pan_metadata.pq'))" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "metadata_ms4['country'] = metadata_ms4.Projection_OGCWKT.apply(lambda x: 'Lithuania' if '34N' in x else 'Cyprus')" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "timestamp_data_map = defaultdict(list)\n", 87 | "for npz_file in tqdm(chosen_samples):\n", 88 | " npz = np.load(filesystem.openbin(f'{NPZ_LOC}/{npz_file}'), allow_pickle=True)\n", 89 | " timestamp_data_map[npz['timetamps_deimos'].item()].append(npz['features'])" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "df_dicts = [] \n", 99 | "for ts, ts_values in timestamp_data_map.items():\n", 100 | " joined = np.concatenate(ts_values)\n", 101 | " mean = np.mean(joined, axis=(0, 1, 2))\n", 102 | " median = np.median(joined, axis=(0, 1, 2))\n", 103 | " std = np.std(joined, axis=(0, 1, 2))\n", 104 | " \n", 105 | " df_dicts.append({'timestamp': ts, 'mean': mean, 'median': median, 'std': std})\n", 106 | "df_norm_s2 = pd.DataFrame(df_dicts)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "for i in range(0, 4): \n", 116 | " df_norm_s2[f'MEAN_{i}'] = df_norm_s2['mean'].apply(lambda x: x[i])\n", 117 | " df_norm_s2[f'STD_{i}'] = df_norm_s2['std'].apply(lambda x: x[i])\n", 118 | " df_norm_s2[f'MEDIAN_{i}'] = df_norm_s2['median'].apply(lambda x: x[i])" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | 
"df_norm_s2_per_timestamp = df_norm_s2.set_index('timestamp').join(metadata_ms4[['START_TIME', 'country']].set_index('START_TIME')).reset_index()" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "fg = sns.FacetGrid(data=df_norm_s2_per_timestamp, hue='country', aspect=2.5, size=6)\n", 137 | "fg.map(plt.scatter, 'timestamp', 'MEDIAN_2').add_legend()\n" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "# Per country" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "timestamp_country_map = {ts: country for ts,country in metadata_ms4[['START_TIME', 'country']].values}" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "country_data_map = defaultdict(list)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "NPZ_LOC = ''" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "sample_filenames = os.listdir(NPZ_LOC)\n", 181 | "chosen_samples = np.random.choice(sample_filenames, int(len(sample_filenames)*0.1), replace=False)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "for npz_file in tqdm(chosen_samples):\n", 191 | " npz = np.load(f'{NPZ_LOC}/{npz_file}', allow_pickle=True)\n", 192 | " \n", 193 | " country = timestamp_country_map[npz['timetamps_deimos'].item()]\n", 194 | " country_data_map[country].append(npz['features'])" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 
201 | "outputs": [], 202 | "source": [ 203 | "df_dicts = [] \n", 204 | "for country, country_values in country_data_map.items():\n", 205 | " joined = np.concatenate(country_values)\n", 206 | " mean = np.mean(joined, axis=(0, 1, 2))\n", 207 | " median = np.median(joined, axis=(0, 1, 2))\n", 208 | " std = np.std(joined, axis=(0, 1, 2))\n", 209 | " \n", 210 | " df_dicts.append({'country': country, \n", 211 | " 'mean_0': mean[0], 'mean_1': mean[1], 'mean_2': mean[2], 'mean_3': mean[3],\n", 212 | " 'median_0': median[0], 'median_1': median[1], 'median_2': median[2], 'median_3': median[3],\n", 213 | " 'std_0': std[0], 'std_1': std[1], 'std_2': std[2], 'std_3': std[3]})\n", 214 | "\n", 215 | "df_norm_s2_per_country = pd.DataFrame(df_dicts)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "df_norm_s2_per_country" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "df_norm_s2_per_country.to_parquet(filesystem.openbin('metadata/s2_norm_per_country.pq', 'wb'))" 234 | ] 235 | } 236 | ], 237 | "metadata": { 238 | "kernelspec": { 239 | "display_name": "Environment (conda_tensorflow2_p36)", 240 | "language": "python", 241 | "name": "conda_tensorflow2_p36" 242 | }, 243 | "language_info": { 244 | "codemirror_mode": { 245 | "name": "ipython", 246 | "version": 3 247 | }, 248 | "file_extension": ".py", 249 | "mimetype": "text/x-python", 250 | "name": "python", 251 | "nbconvert_exporter": "python", 252 | "pygments_lexer": "ipython3", 253 | "version": "3.6.5" 254 | } 255 | }, 256 | "nbformat": 4, 257 | "nbformat_minor": 2 258 | } -------------------------------------------------------------------------------- /notebooks/05d-calculate-scores.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | 
"execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%reload_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "\n", 12 | "%cd /home/ubuntu/dione-sr/" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import os" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import imageio\n", 31 | "import numpy as np\n", 32 | "import pandas as pd\n", 33 | "from fs_s3fs import S3FS\n", 34 | "from matplotlib import pyplot as plt\n", 35 | "from tqdm.auto import tqdm" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "lines_to_next_cell": 2 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "import cv2 as cv\n", 47 | "from hrnet.src.train import resize_batch_images\n", 48 | "from sr.data_loader import ImagesetDataset\n", 49 | "from sr.metrics import METRICS, minshift_loss\n", 50 | "from torch.utils.data import DataLoader" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "aws_access_key_id = ''\n", 60 | "aws_secret_access_key = ''" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "filesystem = S3FS(bucket_name='',\n", 70 | " aws_access_key_id=aws_access_key_id,\n", 71 | " aws_secret_access_key=aws_secret_access_key)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "norm_deimos = {k: v for k, v in np.load(filesystem.openbin('metadata/deimos_min_max_norm.npz')).items()}\n", 81 | "norm_s2 = {k: v for k, v in np.load(filesystem.openbin('metadata/s2_min_max_norm.npz')).items()}\n", 82 | "\n", 83 | "data_df = 
pd.read_parquet(filesystem.openbin('metadata/npz_info_small.pq'))\n", 84 | "country_norm_df = pd.read_parquet(filesystem.openbin('metadata/s2_norm_per_country.pq'))" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "NPZ_FOLDER = ''" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": { 100 | "lines_to_next_cell": 2 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "data_df.head()" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "dataset = ImagesetDataset(imset_dir=NPZ_FOLDER,\n", 114 | " imset_npz_files=data_df.singleton_npz_filename.values,\n", 115 | " filesystem=filesystem,\n", 116 | " country_norm_df=country_norm_df,\n", 117 | " normalize=True,\n", 118 | " norm_deimos_npz=norm_deimos,\n", 119 | " norm_s2_npz=norm_s2,\n", 120 | " time_first=True\n", 121 | " )\n", 122 | "\n", 123 | "dataloader = DataLoader(dataset,\n", 124 | " batch_size=256,\n", 125 | " shuffle=False,\n", 126 | " num_workers=16,\n", 127 | " pin_memory=True)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "SHIFTS = 6" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "### test run on a single batch" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "batch = next(iter(dataloader))" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "batch.keys()" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "lrs = 
batch['lr']\n", 171 | "hrs = batch['hr']\n", 172 | "names = batch['name']\n", 173 | "alphas = batch['alphas']" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "interpolated = resize_batch_images(lrs[:, -1, [-1], ...],\n", 183 | " fx=3, fy=3, interpolation=cv.INTER_CUBIC)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "mse = METRICS['MSE'](hrs[:, [-1], ...], interpolated.float())\n", 193 | "mse_shift, mse_ids = minshift_loss(hrs[:, [-1], ...], interpolated.float(),\n", 194 | " shifts=SHIFTS, metric='MSE')\n", 195 | "mse_shift_c, mse_ids_c = minshift_loss(hrs[:, [-1], ...], interpolated.float(),\n", 196 | " metric='MSE', shifts=SHIFTS, apply_correction=True)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "fig, ax = plt.subplots(figsize=(10, 10))\n", 206 | "ax.scatter(mse_shift.numpy(), mse_shift_c.numpy(), alpha=.3, label='MSE shifted corrected')\n", 207 | "ax.scatter(mse_shift.numpy(), mse.numpy(), alpha=.3, label='MSE')\n", 208 | "ax.plot([0, 1], [0, 1], 'k')\n", 209 | "ax.grid()\n", 210 | "ax.legend()\n", 211 | "ax.set_xlabel('MSE shifted')" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "np.where(mse_shift_c.numpy() > .35)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "np.where(mse_shift_c.numpy() < .02)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "idx = 224\n", 239 | "\n", 240 | "img_de = hrs[idx, [-1], ...].numpy().squeeze()\n", 241 | "img_s2 = 
interpolated[idx].numpy().squeeze()" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "ids = mse_ids_c[idx, :].numpy().astype(np.uint8)\n", 251 | "print(ids)\n", 252 | "\n", 253 | "img_s2 = img_s2[SHIFTS//2:-SHIFTS//2, SHIFTS//2:-SHIFTS//2]\n", 254 | "img_de = img_de[ids[0]:ids[1], ids[2]:ids[3]]\n", 255 | "\n", 256 | "img_s2 = 255*(img_s2-img_s2.min())/(img_s2.max()-img_s2.min())\n", 257 | "img_de = 255*(img_de-img_de.min())/(img_de.max()-img_de.min())\n", 258 | "\n", 259 | "giffile = f's2-deimos-{names[idx]}.gif'\n", 260 | "imageio.mimsave(giffile,\n", 261 | " [img_s2.astype(np.uint8), img_de.astype(np.uint8)],\n", 262 | " duration=0.5)" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": {}, 268 | "source": [ 269 | "## Compute scores on entire dataset of patchlets" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "pq_filename = 'scores-bicubic-32x32.pq'\n", 279 | "\n", 280 | "if not os.path.exists(pq_filename):\n", 281 | "\n", 282 | " scores = []\n", 283 | " for sample in tqdm(dataloader):\n", 284 | " hrs = sample['hr'][:, [-1], ...]\n", 285 | "\n", 286 | " interpolated = resize_batch_images(sample['lr'][:, -1, [-1], ...],\n", 287 | " fx=3, fy=3, interpolation=cv.INTER_CUBIC)\n", 288 | " mse_ = METRICS['MSE'](hrs.float(), interpolated.float())\n", 289 | " mse_shift, _ = minshift_loss(hrs.float(), interpolated.float(),\n", 290 | " metric='MSE', shifts=SHIFTS)\n", 291 | " mse_shift_c, _ = minshift_loss(hrs.float(), interpolated.float(),\n", 292 | " metric='MSE', shifts=SHIFTS, apply_correction=True)\n", 293 | " psnr_shift_c, _ = minshift_loss(hrs.float(), interpolated.float(),\n", 294 | " metric='PSNR', shifts=SHIFTS, apply_correction=True)\n", 295 | " ssim_shift_c, _ = minshift_loss(hrs.float(), interpolated.float(),\n", 296 | " metric='SSIM', 
shifts=SHIFTS, apply_correction=True)\n", 297 | "\n", 298 | " for name, mse, mse_s, mse_sc, psnr, ssim in zip(sample['name'],\n", 299 | " mse_,\n", 300 | " mse_shift,\n", 301 | " mse_shift_c,\n", 302 | " psnr_shift_c,\n", 303 | " ssim_shift_c):\n", 304 | " scores.append({'name': name,\n", 305 | " 'MSE': mse.numpy().astype(np.float32),\n", 306 | " 'MSE_s': mse_s.numpy().astype(np.float32),\n", 307 | " 'MSE_s_c': mse_sc.numpy().astype(np.float32),\n", 308 | " 'PSNR_s_c': psnr.numpy().astype(np.float32),\n", 309 | " 'SSIM_s_c': ssim.numpy().astype(np.float32)})\n", 310 | "\n", 311 | " df = pd.DataFrame(scores)\n", 312 | " print(len(df))\n", 313 | "\n", 314 | " df.MSE = df.MSE.astype(np.float32)\n", 315 | " df.MSE_s = df.MSE_s.astype(np.float32)\n", 316 | " df.MSE_s_c = df.MSE_s_c.astype(np.float32)\n", 317 | " df.PSNR_s_c = df.PSNR_s_c.astype(np.float32)\n", 318 | " df.SSIM_s_c = df.SSIM_s_c.astype(np.float32)\n", 319 | "\n", 320 | " df.to_parquet(pq_filename)\n", 321 | "else:\n", 322 | " df = pd.read_parquet(pq_filename)" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [ 331 | "len(df)" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": null, 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [ 340 | "df.head()" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "fig, ax = plt.subplots(figsize=(15, 10))\n", 350 | "df.MSE.hist(ax=ax, alpha=.3, bins=50, range=(0, 1), label='MSE')\n", 351 | "df.MSE_s.hist(ax=ax, alpha=.3, bins=50, range=(0, 1), label='MSE_s')\n", 352 | "df.MSE_s_c.hist(ax=ax, alpha=.3, bins=50, range=(0, 1), label='MSE_s_c')\n", 353 | "ax.legend()" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [ 362 | "fig, ax = 
plt.subplots(figsize=(15, 10))\n", 363 | "ax.scatter(df.MSE_s_c, df.SSIM_s_c, alpha=.1)" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "fig, ax = plt.subplots(figsize=(15, 10))\n", 373 | "ax.scatter(df.PSNR_s_c, df.SSIM_s_c, alpha=.1)" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "data_df.rename(columns={'singleton_npz_filename': 'name'}, inplace=True)" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": null, 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [ 391 | "scores_df = pd.merge(df, data_df, on='name')" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "scores_df.head()" 401 | ] 402 | } 403 | ], 404 | "metadata": { 405 | "kernelspec": { 406 | "display_name": "Environment (conda_pytorch_p36)", 407 | "language": "python", 408 | "name": "conda_pytorch_p36" 409 | } 410 | }, 411 | "nbformat": 4, 412 | "nbformat_minor": 2 413 | } -------------------------------------------------------------------------------- /notebooks/06-train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "%cd /home/ubuntu/super-resolution/" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import json\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "import pandas as pd\n", 24 | "from fs_s3fs import S3FS\n", 25 | "\n", 26 | "import wandb\n", 27 | "from hrnet.src.train import main\n", 28 | "import torchvision\n", 29 | "from 
sr.niva_models import TorchUnetv2\n", 30 | "from torch import nn \n", 31 | "import os \n", 32 | "from types import SimpleNamespace\n", 33 | "\n", 34 | "\n", 35 | "\n", 36 | "aws_access_key_id = ''\n", 37 | "aws_secret_access_key = '\n", 38 | "\n", 39 | "filesystem = S3FS(\n", 40 | " bucket_name='',\n", 41 | " aws_access_key_id=aws_access_key_id,\n", 42 | " aws_secret_access_key=aws_secret_access_key, \n", 43 | " region='eu-central-1')\n", 44 | "\n", 45 | "country_norm_df = pd.read_parquet(filesystem.openbin('metadata/s2_norm_per_country.pq'))\n", 46 | "\n", 47 | "data_df = pd.read_parquet(filesystem.openbin('metadata/npz_info_small.pq'))\n", 48 | "data_df.reset_index(inplace=True)\n", 49 | "\n", 50 | "scores_df = pd.read_parquet(filesystem.openbin('baseline-scores-sr/scores-bicubic-32x32-2p5m-hm.pq')\n", 51 | " ).rename(columns={'name': 'singleton_npz_filename'})\n", 52 | "\n", 53 | "data_df = pd.merge(data_df, scores_df, on='singleton_npz_filename')\n", 54 | "data_df['MSE_ratio'] = data_df['MSE_s']/data_df['MSE_s_c']" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "import torch" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "### Data setup" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "norm_deimos = {k: v for k, v in np.load(filesystem.openbin('metadata/deimos_min_max_norm.npz')).items()}\n", 80 | "norm_s2 = {k: v for k, v in np.load(filesystem.openbin('metadata/s2_min_max_norm.npz')).items()}" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "# Filter data" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "filtered_data = data_df[(data_df['SSIM_s_c'] > .2) &\n", 97 | " (data_df['PSNR_s_c'] > 
10) &\n", 98 | " (data_df['MSE_ratio'] < 10) &\n", 99 | " (data_df['is_shadow_v2'] == False) &\n", 100 | " (data_df['countries'] == 'Lithuania') &\n", 101 | " (data_df['num_tstamps'] > 1)]" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "len(filtered_data)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "### Wandb setup" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "# ! wandb login " 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": { 132 | "lines_to_next_cell": 2 133 | }, 134 | "source": [ 135 | "# Training" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": { 142 | "lines_to_next_cell": 2 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "wandb.init(project='', entity='', config=config)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "with open('input/config-local-hrn-pix-shu.json') as f:\n", 156 | " config = json.load(f)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "MODEL_DIR = ''\n", 166 | "with open(os.path.join(MODEL_DIR, 'model_cfg.json'), 'r') as jfile:\n", 167 | " model_cfg = json.load(jfile)\n", 168 | " \n", 169 | "perceptual_model = TorchUnetv2(4, config=SimpleNamespace(**model_cfg))\n", 170 | "perceptual_model.load_state_dict(torch.load(os.path.join(MODEL_DIR, 'model.pth')))\n", 171 | "perceptual_model.eval()\n", 172 | "perceptual_model.to(torch.device('cuda'))" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "# Perceptual loss model " 180 | ] 181 | }, 182 | { 
183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "model = main(config,\n", 189 | " filtered_data,\n", 190 | " normalize=True,\n", 191 | " country_norm_df=country_norm_df,\n", 192 | " norm_deimos_npz=norm_deimos,\n", 193 | " norm_s2_npz=norm_s2, perceptual_loss_model=perceptual_model)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [] 202 | } 203 | ], 204 | "metadata": { 205 | "kernelspec": { 206 | "display_name": "Python 3.6.9 64-bit ('venv': virtualenv)", 207 | "language": "python", 208 | "name": "python369jvsc74a57bd008539c228c0b1d46fd3ab380299090bd67be578e8cdd5c516ba9f15efc81c90d" 209 | }, 210 | "language_info": { 211 | "codemirror_mode": { 212 | "name": "ipython", 213 | "version": 3 214 | }, 215 | "file_extension": ".py", 216 | "mimetype": "text/x-python", 217 | "name": "python", 218 | "nbconvert_exporter": "python", 219 | "pygments_lexer": "ipython3", 220 | "version": "3.6.9" 221 | } 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 2 225 | } -------------------------------------------------------------------------------- /notebooks/07-predict.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "lines_to_next_cell": 0 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "%reload_ext autoreload\n", 12 | "%autoreload 2\n", 13 | "%matplotlib inline\n", 14 | "\n", 15 | "import os\n", 16 | "from datetime import datetime\n", 17 | "\n", 18 | "import numpy as np\n", 19 | "import pandas as pd\n", 20 | "import yaml\n", 21 | "from eolearn.core import EOPatch, FeatureType, OverwritePermission\n", 22 | "from eolearn.io import ExportToTiff\n", 23 | "from fs_s3fs import S3FS\n", 24 | "from matplotlib import pyplot as plt\n", 25 | "from skimage.exposure import 
match_histograms\n", 26 | "from tqdm.auto import tqdm\n", 27 | "\n", 28 | "import torch\n", 29 | "import wandb\n", 30 | "from cv2 import INTER_CUBIC, GaussianBlur, resize\n", 31 | "from hrnet.src.predict import Model\n", 32 | "from hrnet.src.train import resize_batch_images\n", 33 | "from sr.data_loader import EopatchPredictionDataset, ImagesetDataset\n", 34 | "from torch.utils.data import DataLoader" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "from sr.metrics import minshift_loss" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "## 1.0 Configuration" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "# ! wandb login " 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "aws_access_key_id = ''\n", 69 | "aws_secret_access_key = ''\n", 70 | "\n", 71 | "filesystem = S3FS(\n", 72 | " bucket_name='',\n", 73 | " aws_access_key_id=aws_access_key_id,\n", 74 | " aws_secret_access_key=aws_secret_access_key, region='eu-central-1')\n", 75 | "\n", 76 | "\n", 77 | "# If 'LOCAL' it will be loaded from local wandb storage, if 'WANDB' from online storage\n", 78 | "MODEL_LOCATION = 'LOCAL'\n", 79 | "\n", 80 | "MODEL_NAME = ''\n", 81 | "MODEL_PREFIX = ''\n", 82 | "MATCHES_S2 = True\n", 83 | "LOCATION = f'wandb/latest-run/files/'" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "EOP_COUNTRIES_PQ = f'eop-countries_overlapped.pq'" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "if not os.path.exists(EOP_COUNTRIES_PQ):\n", 102 | " eops_countries = []\n", 103 | " for eopfname in 
filesystem.listdir(''):\n", 104 | " eop = EOPatch.load(os.path.join('',\n", 105 | " eopfname), filesystem=filesystem, lazy_loading=True)\n", 106 | " eops_countries.append({'country': 'Lithuania' if str(eop.bbox.crs) == 'EPSG:32634' else 'Cyprus',\n", 107 | " 'eopatch': eopfname})\n", 108 | " pd.DataFrame(eops_countries).to_parquet(\n", 109 | " f'eop-countries_overlapped.pq')" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "checkpoint_filename = 'HRNet.pth'" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "if MODEL_LOCATION == 'WANDB':\n", 128 | " model_checkpoint = wandb.restore(\n", 129 | " checkpoint_filename, run_path=LOCATION, replace=True)\n", 130 | " model_checkpoint = open(checkpoint_filename, 'rb')\n", 131 | " model_config_yaml = yaml.load(wandb.restore(\n", 132 | " 'config.yaml', run_path=LOCATION, replace=True))\n", 133 | "elif MODEL_LOCATION == 'LOCAL':\n", 134 | " model_checkpoint = os.path.join(LOCATION, checkpoint_filename)\n", 135 | " model_config_yaml = yaml.load(open(os.path.join(LOCATION, 'config.yaml')))\n", 136 | "\n", 137 | " assert os.path.isfile(model_checkpoint)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "config = {k: v['value']\n", 147 | " for k, v in model_config_yaml.items() if 'wandb' not in k}" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "country_norm_df = pd.read_parquet(\n", 157 | " filesystem.openbin('metadata/s2_norm_per_country.pq'))\n", 158 | "\n", 159 | "norm_deimos = {k: v for k, v in np.load(\n", 160 | " filesystem.openbin('metadata/deimos_min_max_norm.npz')).items()}\n", 161 | "norm_s2 = {k: v for k, v in 
np.load(\n", 162 | " filesystem.openbin('metadata/s2_min_max_norm.npz')).items()}\n", 163 | "\n", 164 | "data_df = pd.read_parquet(filesystem.openbin('metadata/npz_info_small.pq'))\n", 165 | "data_df.reset_index(inplace=True)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "scores_df = pd.read_parquet(filesystem.openbin('scores-bicubic-32x32.pq')).rename(columns={'name': 'singleton_npz_filename'})\n", 175 | "data_df = pd.merge(data_df, scores_df, on='singleton_npz_filename')\n", 176 | "data_df['MSE_ratio'] = data_df['MSE_s']/data_df['MSE_s_c']" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "filtered_data = data_df[(data_df['SSIM_s_c'] > .2) &\n", 186 | " (data_df['PSNR_s_c'] > 10) &\n", 187 | " (data_df['MSE_ratio'] < 10) &\n", 188 | " (data_df['is_shadow_v2'] == False) &\n", 189 | " (data_df['countries'] == 'Lithuania') &\n", 190 | " (data_df['num_tstamps'] > 1)]" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "model = Model(config)\n", 200 | "model.load_checkpoint(checkpoint_file=model_checkpoint)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "## 1.2 Load data" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "test_samples = filtered_data[(filtered_data.train_test_validation == 'validation')].sample(\n", 217 | " 2000).singleton_npz_filename.values\n", 218 | "\n", 219 | "test_dataset = ImagesetDataset(\n", 220 | " imset_dir=config['paths']['prefix'],\n", 221 | " imset_npz_files=test_samples,\n", 222 | " country_norm_df=country_norm_df,\n", 223 | " normalize=True,\n", 224 | " 
norm_deimos_npz=norm_deimos,\n", 225 | " norm_s2_npz=norm_s2,\n", 226 | " channels_labels=config['training']['channels_labels'],\n", 227 | " channels_feats=config['training']['channels_features'],\n", 228 | " time_first=True,\n", 229 | " n_views=config['training']['n_views'],\n", 230 | " histogram_matching=config['training']['histogram_matching']\n", 231 | ")" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "def normalise_bands(eop, bands_name, eop_name, norm_df):\n", 241 | " \"\"\" Normalise bands \"\"\"\n", 242 | " df_means = norm_df[norm_df.eopatch == eop_name].groupby('month').mean()[cols_mean]\n", 243 | " df_std = norm_df[norm_df.eopatch == eop_name].groupby('month').mean()[cols_std]\n", 244 | " \n", 245 | " bands = eop.data[bands_name]\n", 246 | " \n", 247 | " normalised = np.empty(bands.shape, dtype=np.float32)\n", 248 | " \n", 249 | " for nb, (band, ts) in enumerate(zip(bands, eop.timestamp)):\n", 250 | " means = df_means.loc[ts.strftime('%Y-%m')].values\n", 251 | " stds = df_std.loc[ts.strftime('%Y-%m')].values\n", 252 | " \n", 253 | " normalised[nb] = (band - means) / stds\n", 254 | " \n", 255 | " return normalised" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "test_dataloader = DataLoader(\n", 265 | " test_dataset,\n", 266 | " batch_size=128,\n", 267 | " shuffle=False,\n", 268 | " num_workers=8,\n", 269 | " pin_memory=True)" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "sample = test_dataset[0]" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "hr = np.moveaxis(sample['hr'].numpy(), 0, 2)\n", 288 | "\n", 289 | "hr_ = resize(GaussianBlur(hr, 
ksize=(7, 7), sigmaX=4), None, fx=1/4, fy=1/4)\n", 290 | "\n", 291 | "hr__ = resize(hr_, None, fx=4, fy=4, interpolation=INTER_CUBIC)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "lr = np.moveaxis(\n", 301 | " sample['lr'][np.sum(sample['alphas'].int().numpy())-1].numpy(), 0, 2)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "fig, axs = plt.subplots(ncols=4, figsize=(15, 7.5))\n", 311 | "axs[0].imshow(hr[..., [2, 1, 0]])\n", 312 | "axs[1].imshow(hr_[..., [2, 1, 0]])\n", 313 | "axs[2].imshow(lr[..., [2, 1, 0]])\n", 314 | "axs[3].imshow(hr__[..., [2, 1, 0]])" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "lr_ = match_histograms(lr, hr_, multichannel=True)\n", 324 | "lr__ = resize(lr_, None, fx=4, fy=4, interpolation=INTER_CUBIC)" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": null, 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [ 333 | "fig, axs = plt.subplots(ncols=4, figsize=(15, 7.5))\n", 334 | "axs[0].imshow(hr[..., [2, 1, 0]])\n", 335 | "axs[1].imshow(hr_[..., [2, 1, 0]])\n", 336 | "axs[2].imshow(lr_[..., [2, 1, 0]])\n", 337 | "axs[3].imshow(lr__[..., [2, 1, 0]])" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "ssims_bi_de, psnrs_bi_de = [], []\n", 347 | "ssims_bi_s2, psnrs_bi_s2 = [], []\n", 348 | "ssims_sr, psnrs_sr = [], []\n", 349 | "\n", 350 | "for sample in tqdm(test_dataloader):\n", 351 | " sr = torch.from_numpy(model(sample))\n", 352 | " alphas = sample['alphas'].float()\n", 353 | " lrs = sample['lr'][np.arange(len(alphas)),\n", 354 | " torch.sum(alphas, dim=1, dtype=torch.int64) - 1]\n", 355 | " hr = 
sample['hr'].float()\n", 356 | "\n", 357 | " lrs_hm = torch.tensor([match_histograms(np.moveaxis(lri.numpy(), 0, 2),\n", 358 | " np.moveaxis(hri.numpy(), 0, 2),\n", 359 | " multichannel=True)\n", 360 | " for (lri, hri) in zip(lrs, hr)])\n", 361 | " \n", 362 | " lrs_hm = lrs_hm.permute([0, 3, 1, 2])\n", 363 | "\n", 364 | " baseline_s2 = resize_batch_images(lrs_hm, fx=4, fy=4).float()\n", 365 | "\n", 366 | " baseline_de = torch.tensor([resize(resize(GaussianBlur(np.moveaxis(hr_.numpy(), 0, 2),\n", 367 | " ksize=(7, 7),\n", 368 | " sigmaX=4), None, fx=1/4, fy=1/4),\n", 369 | " None, fx=4, fy=4, interpolation=INTER_CUBIC) for hr_ in hr])\n", 370 | " baseline_de = baseline_de.permute([0, 3, 1, 2])\n", 371 | "\n", 372 | " ssims_sr.append(minshift_loss(hr, sr, metric='SSIM', apply_correction=False)[0])\n", 373 | " ssims_bi_de.append(minshift_loss(hr, baseline_de, metric='SSIM', apply_correction=False)[0])\n", 374 | " ssims_bi_s2.append(minshift_loss(hr, baseline_s2, metric='SSIM', apply_correction=False)[0])\n", 375 | "\n", 376 | " psnrs_sr.append(minshift_loss(hr, sr, metric='PSNR', apply_correction=False)[0])\n", 377 | " psnrs_bi_de.append(minshift_loss(hr, baseline_de, metric='PSNR', apply_correction=False)[0])\n", 378 | " psnrs_bi_s2.append(minshift_loss(hr, baseline_s2, metric='PSNR', apply_correction=False)[0])" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [ 387 | "ssim_bi_de = np.array([jj for item in ssims_bi_de for jj in item.numpy()])\n", 388 | "ssim_bi_s2 = np.array([jj for item in ssims_bi_s2 for jj in item.numpy()])\n", 389 | "ssim_sr = np.array([jj for item in ssims_sr for jj in item.numpy()])\n", 390 | "\n", 391 | "psnr_bi_de = np.array([jj for item in psnrs_bi_de for jj in item.numpy()])\n", 392 | "psnr_bi_s2 = np.array([jj for item in psnrs_bi_s2 for jj in item.numpy()])\n", 393 | "psnr_sr = np.array([jj for item in psnrs_sr for jj in item.numpy()])" 394 | ] 
395 | } 396 | ], 397 | "metadata": { 398 | "kernelspec": { 399 | "display_name": "Environment (conda_pytorch_p36)", 400 | "language": "python", 401 | "name": "conda_pytorch_p36" 402 | }, 403 | "language_info": { 404 | "codemirror_mode": { 405 | "name": "ipython", 406 | "version": 3 407 | }, 408 | "file_extension": ".py", 409 | "mimetype": "text/x-python", 410 | "name": "python", 411 | "nbconvert_exporter": "python", 412 | "pygments_lexer": "ipython3", 413 | "version": "3.6.7" 414 | } 415 | }, 416 | "nbformat": 4, 417 | "nbformat_minor": 2 418 | } -------------------------------------------------------------------------------- /notebooks/07b-predict-eopatches.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%reload_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "%matplotlib inline\n", 12 | "\n", 13 | "import os\n", 14 | "\n", 15 | "import numpy as np\n", 16 | "import pandas as pd\n", 17 | "import yaml\n", 18 | "from eolearn.core import EOPatch, OverwritePermission\n", 19 | "from fs_s3fs import S3FS\n", 20 | "from matplotlib import pyplot as plt\n", 21 | "from tqdm.auto import tqdm\n", 22 | "\n", 23 | "import torch\n", 24 | "import wandb\n", 25 | "from hrnet.src.predict import Model" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## 1.0 Configuration" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# ! 
wandb login " 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "aws_access_key_id = ''\n", 51 | "aws_secret_access_key = ''\n", 52 | "\n", 53 | "filesystem = S3FS(\n", 54 | " bucket_name='',\n", 55 | " aws_access_key_id=aws_access_key_id,\n", 56 | " aws_secret_access_key=aws_secret_access_key, region='eu-central-1')\n", 57 | "\n", 58 | "\n", 59 | "MODEL_LOCATION = 'LOCAL' # If 'LOCAL' it will be loaded from local wandb storage, if 'WANDB' from online storage\n", 60 | "MODEL_NAME = ''\n", 61 | "MODEL_PREFIX = ''\n", 62 | "MATCHES_S2 = True\n", 63 | "LOCATION = f'wandb/latest-run/files/'" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "EOP_COUNTRIES_PQ = f'{DIONE_DIR}/eop-countries_overlapped.pq'" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "if not os.path.exists(EOP_COUNTRIES_PQ):\n", 82 | " eops_countries = []\n", 83 | " for eopfname in filesystem.listdir(''):\n", 84 | " eop = EOPatch.load(os.path.join('',\n", 85 | " eopfname), filesystem=filesystem, lazy_loading=True)\n", 86 | " eops_countries.append({'country': 'Lithuania' if str(eop.bbox.crs) == 'EPSG:32634' else 'Cyprus',\n", 87 | " 'eopatch': eopfname})\n", 88 | " pd.DataFrame(eops_countries).to_parquet(f'{DIONE_DIR}/eop-countries_overlapped.pq')" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "checkpoint_filename = 'HRNet.pth'" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "if MODEL_LOCATION == 'WANDB':\n", 107 | " model_checkpoint = wandb.restore(checkpoint_filename, run_path=LOCATION, replace=True)\n", 108 | " model_checkpoint = 
open(checkpoint_filename, 'rb')\n", 109 | " model_config_yaml = yaml.load(wandb.restore('config.yaml', run_path=LOCATION, replace=True))\n", 110 | "elif MODEL_LOCATION == 'LOCAL':\n", 111 | " model_checkpoint = os.path.join(LOCATION, checkpoint_filename)\n", 112 | " model_config_yaml = yaml.load(open(os.path.join(LOCATION, 'config.yaml')))\n", 113 | "\n", 114 | " assert os.path.isfile(model_checkpoint)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "config = {k: v['value'] for k, v in model_config_yaml.items() if 'wandb' not in k}" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": { 130 | "lines_to_next_cell": 2 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "country_norm_df = pd.read_parquet(filesystem.openbin('metadata/s2_norm_per_country.pq'))\n", 135 | "\n", 136 | "norm_deimos = {k: v for k, v in np.load(filesystem.openbin('metadata/deimos_min_max_norm.npz')).items()}\n", 137 | "norm_s2 = {k: v for k, v in np.load(filesystem.openbin('metadata/s2_min_max_norm.npz')).items()}\n", 138 | "\n", 139 | "data_df = pd.read_parquet(filesystem.openbin('metadata/npz_info_small.pq'))\n", 140 | "data_df.reset_index(inplace=True)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "model = Model(config)\n", 150 | "model.load_checkpoint(checkpoint_file=model_checkpoint)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "# Predict on EOPatches" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "def _filter_cloudy_s2(eop, max_cc):\n", 167 | " idxs = []\n", 168 | " for i, _ in enumerate(eop.timestamp):\n", 169 | " if (eop.mask['CLM'][i, ...].mean() <= max_cc) and 
(eop.mask['IS_DATA'].mean() == 1):\n", 170 | " idxs.append(i)\n", 171 | " eop.data['BANDS'] = eop.data['BANDS'][idxs, ...]\n", 172 | " eop.data['CLP'] = eop.data['CLP'][idxs, ...]\n", 173 | " eop.mask['CLM'] = eop.mask['CLM'][idxs, ...]\n", 174 | " eop.mask['IS_DATA'] = eop.mask['IS_DATA'][idxs, ...]\n", 175 | " eop.scalar['NORM_FACTORS'] = eop.scalar['NORM_FACTORS'][idxs, ...]\n", 176 | "\n", 177 | " eop.timestamp = list(np.array(eop.timestamp)[idxs])\n", 178 | " return eop\n", 179 | "\n", 180 | "\n", 181 | "def _timestamps_within_date(timestamps, start_date, end_date):\n", 182 | " return [i for i, ts in enumerate(timestamps) if ts >= start_date and ts < end_date]\n", 183 | "\n", 184 | "\n", 185 | "def predict_sr_images(eopatch_name: str,\n", 186 | " model: Model,\n", 187 | " model_prefix: str,\n", 188 | " scale_factor: int = 4,\n", 189 | " filesystem: S3FS = None,\n", 190 | " normalize: bool = True,\n", 191 | " country_norm_df: pd.DataFrame = None,\n", 192 | " norm_s2_npz: np.lib.npyio.NpzFile = None,\n", 193 | " max_cc: float = 0.05,\n", 194 | " n_views: int = 8,\n", 195 | " padding: str = 'zeros'):\n", 196 | " \"\"\" Predict an SR image at the EOPatch level for all timeframes available \"\"\"\n", 197 | " assert padding in ['zeros', 'repeat']\n", 198 | "\n", 199 | " eopatch = EOPatch.load(eopatch_name,\n", 200 | " filesystem=filesystem,\n", 201 | " lazy_loading=True)\n", 202 | " noncloudy = _filter_cloudy_s2(eopatch, max_cc=max_cc)\n", 203 | "# ts_idxs = _timestamps_within_date(noncloudy.timestamp, start_date, end_date)\n", 204 | " features = noncloudy.data['BANDS'] / 10000\n", 205 | "# filtered_ts = [eopatch.timestamp[tsi] for tsi in ts_idxs]\n", 206 | "\n", 207 | " if normalize:\n", 208 | " country = 'Lithuania' if str(eopatch.bbox.crs) == 'EPSG:32634' else 'Cyprus' # WARNING EXTREMELY HACKY HACKY\n", 209 | " country_stats = country_norm_df[country_norm_df.country == str(country)]\n", 210 | "\n", 211 | " norm_median = country_stats[['median_0', 'median_1', 
'median_2', 'median_3']].values\n", 212 | " norm_std = country_stats[['std_0', 'std_1', 'std_2', 'std_3']].values\n", 213 | "\n", 214 | " features = (features - norm_median) / norm_std\n", 215 | "\n", 216 | " s2_p1 = norm_s2_npz['p1']\n", 217 | " s2_p99 = norm_s2_npz['p99']\n", 218 | "\n", 219 | " features = (features - s2_p1) / (s2_p99 - s2_p1)\n", 220 | "\n", 221 | " n_frames, height, width, nch = features.shape\n", 222 | " super_resolved = np.empty((n_frames,\n", 223 | " height*scale_factor,\n", 224 | " width*scale_factor,\n", 225 | " nch), dtype=np.uint16)\n", 226 | " actual_n_views = np.array([np.min([n_views, nfr+1])\n", 227 | " for nfr in np.arange(n_frames)]).astype(np.uint8)\n", 228 | "\n", 229 | " for nfr in np.arange(n_frames):\n", 230 | " inarr = None\n", 231 | " alphas = None\n", 232 | " if nfr < n_views:\n", 233 | " inarr = np.concatenate([features[:nfr+1],\n", 234 | " np.zeros((n_views-nfr-1, height, width, nch),\n", 235 | " dtype=np.float32)],\n", 236 | " axis=0)\n", 237 | " alphas = np.zeros(n_views, dtype=np.uint8)\n", 238 | " alphas[:nfr+1] = 1\n", 239 | " else:\n", 240 | " inarr = features[nfr-n_views+1:nfr+1]\n", 241 | " alphas = np.ones(n_views, dtype=np.uint8)\n", 242 | "\n", 243 | " # TxHxWxC -> TxCxHxW\n", 244 | " inarr = np.moveaxis(inarr, -1, 1)\n", 245 | "\n", 246 | "# np.testing.assert_array_equal(inarr[nfr if nfr < n_views else -1], features[nfr])\n", 247 | "\n", 248 | " sr = model({'lr': torch.from_numpy(inarr.copy()),\n", 249 | " 'alphas': torch.from_numpy(alphas),\n", 250 | " 'name': eopatch_name})\n", 251 | "\n", 252 | " # channels back to last\n", 253 | " sr = np.moveaxis(sr.squeeze(), 0, 2)\n", 254 | "\n", 255 | " # denormalise\n", 256 | " sr = (sr * (s2_p99 - s2_p1) + s2_p1) * norm_std + norm_median\n", 257 | "\n", 258 | " super_resolved[nfr] = (np.clip(sr, 0, 3)*10000).astype(np.uint16)\n", 259 | "\n", 260 | " eop_sr = EOPatch(bbox=eopatch.bbox, timestamp=noncloudy.timestamp)\n", 261 | " eop_sr.data[f'SR-{model_prefix.upper()}'] = 
super_resolved\n", 262 | " eop_sr.data['S2'] = noncloudy.data['BANDS'].astype(np.uint16)\n", 263 | " eop_sr.scalar['N_VIEWS'] = actual_n_views[..., np.newaxis]\n", 264 | "\n", 265 | " return eop_sr" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "eops_folder = ''\n", 275 | "deimos_eops_folder = ''" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "eop_countries = pd.read_parquet(EOP_COUNTRIES_PQ)" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "eopatch_names = eop_countries[eop_countries.country == 'Lithuania'].eopatch.unique()" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "eop_sr = predict_sr_images(f'{eops_folder}/{eopatch_names[0]}',\n", 303 | " model,\n", 304 | " MODEL_PREFIX,\n", 305 | " scale_factor=4,\n", 306 | " country_norm_df=country_norm_df,\n", 307 | " filesystem=filesystem,\n", 308 | " normalize=True,\n", 309 | " norm_s2_npz=norm_s2,\n", 310 | " max_cc=0.05,\n", 311 | " n_views=config['training']['n_views'])" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "eop_sr" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "metadata": { 327 | "lines_to_next_cell": 2 328 | }, 329 | "outputs": [], 330 | "source": [ 331 | "fig, axs = plt.subplots(ncols=2, nrows=17, figsize=(15, 17*7.5))\n", 332 | "\n", 333 | "for ni, (s2, sr) in enumerate(zip(eop_sr.data['S2'], eop_sr.data[f'SR-{MODEL_PREFIX.upper()}'])):\n", 334 | " axs[ni][0].imshow(2.5*s2[..., [2, 1, 0]]/10000)\n", 335 | " axs[ni][1].imshow(2.5*sr[..., [2, 1, 
0]]/10000)\n", 336 | " axs[ni][0].set_title(f'S2 - {eop_sr.timestamp[ni]}')\n", 337 | " axs[ni][1].set_title(f'SR - {eop_sr.scalar[\"N_VIEWS\"][ni][0]} actual views')\n", 338 | "\n", 339 | "fig.tight_layout()" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "eops_sr_folder = f'eopatches-{MODEL_PREFIX}/'" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": null, 354 | "metadata": {}, 355 | "outputs": [], 356 | "source": [ 357 | "eops_sr_folder" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": null, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "for eopatch_name in tqdm(eopatch_names):\n", 367 | " try:\n", 368 | " eop_sr = predict_sr_images(f'{eops_folder}/{eopatch_name}',\n", 369 | " model,\n", 370 | " MODEL_PREFIX,\n", 371 | " scale_factor=4,\n", 372 | " country_norm_df=country_norm_df,\n", 373 | " filesystem=filesystem,\n", 374 | " normalize=True,\n", 375 | " norm_s2_npz=norm_s2,\n", 376 | " max_cc=0.05,\n", 377 | " n_views=config['training']['n_views'])\n", 378 | " eop_sr.save(f'{eops_sr_folder}/{eopatch_name}',\n", 379 | " filesystem=filesystem,\n", 380 | " overwrite_permission=OverwritePermission.OVERWRITE_FEATURES)\n", 381 | " del eop_sr\n", 382 | " except RuntimeError:\n", 383 | " print(f'Error in {eopatch_name}')" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": null, 389 | "metadata": {}, 390 | "outputs": [], 391 | "source": [] 392 | } 393 | ], 394 | "metadata": { 395 | "kernelspec": { 396 | "display_name": "Environment (conda_pytorch_p36)", 397 | "language": "python", 398 | "name": "conda_pytorch_p36" 399 | }, 400 | "language_info": { 401 | "codemirror_mode": { 402 | "name": "ipython", 403 | "version": 3 404 | }, 405 | "file_extension": ".py", 406 | "mimetype": "text/x-python", 407 | "name": "python", 408 | "nbconvert_exporter": "python", 409 | 
"pygments_lexer": "ipython3", 410 | "version": "3.6.7" 411 | } 412 | }, 413 | "nbformat": 4, 414 | "nbformat_minor": 2 415 | } -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest>=4.0.0 2 | pytest-cov 3 | codecov 4 | pylint -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fs 2 | dataclasses 3 | numpy 4 | pandas 5 | sentinelhub 6 | eo-learn-core 7 | opencv-python 8 | matplotlib 9 | torch 10 | torchvision 11 | wandb 12 | tqdm 13 | piqa==1.1.0 14 | tensorboardX 15 | pyarrow 16 | scikit-image -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from setuptools import setup, find_packages 4 | 5 | 6 | def parse_requirements(file): 7 | with open(os.path.join(os.path.dirname(__file__), file)) as req_file: 8 | return [line.strip() for line in req_file if '/' not in line] 9 | 10 | 11 | def get_version(): 12 | with open(os.path.join(os.path.dirname(__file__), 'sr', '__init__.py')) as file: 13 | return re.findall("__version__ = \'(.*)\'", file.read())[0] 14 | 15 | 16 | setup( 17 | name='sr', 18 | python_requires='>=3.6', 19 | version=get_version(), 20 | description='EO Research - Super Resolution', 21 | url='https://github.com/sentinel-hub/multi-temporal-super-resolution', 22 | author='Sinergise EO research team', 23 | author_email='eoresearch@sinergise.com', 24 | packages=find_packages(), 25 | install_requires=parse_requirements('requirements.txt'), 26 | extras_require={ 27 | 'DEV': parse_requirements('requirements-dev.txt') 28 | }, 29 | zip_safe=False 30 | ) -------------------------------------------------------------------------------- 
def ssim(hrs: torch.Tensor, srs: torch.Tensor) -> torch.Tensor:
    """ Computes the SSIM between high-res and super-resolved images with piqa's SSIM module.

    The module is built with ``reduction='none'`` and moved to the device of ``hrs`` before it is applied.

    :param hrs: 4-dimensional tensor of high-res images; the second dimension is used as the number of channels
    :param srs: tensor of super resolved images, passed to the SSIM module together with ``hrs``
    :returns: output of the SSIM module (presumably one value per image, given ``reduction='none'``)
    """
    # The 4-way unpacking doubles as a check that the input is 4-dimensional
    (_, channel_count, _, _) = hrs.shape

    ssim_module = SSIM(n_channels=channel_count, reduction='none')
    ssim_module = ssim_module.to(hrs.device)

    return ssim_module(hrs, srs)
def minshift_loss(hrs: torch.Tensor, srs: torch.Tensor, metric: str,
                  shifts: int = 5, apply_correction: bool = False,) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
    """
    Computes a metric over shifted versions on the high-resolution image. The best score over the shifts is
    returned: the minimum for loss-like metrics (MAE, MSE, MIXED) and the maximum for all other metrics.

    :param hrs: tensor (B, C, H, W), high-res images
    :param srs: tensor (B, C, H, W), super resolved images
    :param metric: name of metric to compute
    :param shifts: size of shifts in x and y dimensions to consider. All possible (i,j) positions are considered
    :param apply_correction: whether to apply brightness correction or not
    :returns scores: tensor (B), metric for each super resolved image.
    :returns indices: tensor (B, 4) holding, per image, the (row_start, row_end, col_start, col_end) crop of the
        HR image with the best alignment to the SR image
    :raises ValueError: if ``shifts`` is smaller than 1
    """
    if shifts < 1:
        raise ValueError(f'shifts must be a positive integer, got {shifts}')

    _, _, h, w = srs.shape

    border = shifts // 2
    h_, w_ = h - 2*border, w - 2*border

    # Explicit end indices: `border:-border` degenerates to the empty slice `0:-0` when shifts == 1
    srs_mid = srs[..., border:h - border, border:w - border]
    scores = []
    offsets = []

    for i in range(shifts):
        for j in range(shifts):
            hrs_shift = hrs[..., i:i + h_, j:j + w_]

            score = calculate_metrics(hrs_shift, srs_mid, metrics=metric, apply_correction=apply_correction)
            scores.append(score)
            offsets.append((i, i+h_, j, j+w_))

    scores = torch.stack(scores)

    # MIXED = MSE + weight * (1 - SSIM) is a loss as well, so its best alignment is the minimum, not the maximum
    lower_is_better = metric in ('MAE', 'MSE', 'MIXED')
    scores, indices = torch.min(scores, dim=0) if lower_is_better else torch.max(scores, dim=0)

    # The offsets table lives on the CPU, so index it with CPU indices even when the scores are on another device
    indices = torch.Tensor(offsets)[indices.cpu()]

    return scores, indices
class TorchUnetv2(nn.Module):
    """ U-Net style network with three chained output heads (distance, boundary, extent).

    The encoder has three levels of double convolutions, each followed by a strided convolution that acts as
    pooling. The decoder mirrors it with transposed convolutions and skip connections. The boundary head
    additionally receives the distance output, and the extent head receives both distance and boundary outputs.
    """

    def __init__(self, in_channels, config):
        """
        Args:
            in_channels : int, number of channels of the input tensor
            config : configuration object exposing the attributes features_root, conv_size, conv_stride,
                     keep_prob, pool_size, pool_stride, deconv_size and n_classes
        """
        super().__init__()
        # TODO: do this recursively as done for TF
        root = config.features_root
        n_classes = config.n_classes

        # NOTE: sub-modules are created in the same order as before, so state_dict keys and the order in which
        # random initial weights are drawn are unchanged.

        # Encoder
        self.conv_1 = self._conv_block(in_channels, root, config)
        self.conv_pool_1 = self._pool_block(root, config)

        self.conv_2 = self._conv_block(root, 2*root, config)
        self.conv_pool_2 = self._pool_block(2*root, config)

        self.conv_3 = self._conv_block(2*root, 4*root, config)
        self.conv_pool_3 = self._pool_block(4*root, config)

        # Bottleneck
        self.conv_4 = self._conv_block(4*root, 8*root, config)

        # Decoder; each conv block consumes the upsampled features concatenated with the matching skip connection
        self.deconv_1 = self._deconv_block(8*root, 4*root, config)
        self.conv_5 = self._conv_block(8*root, 4*root, config)

        self.deconv_2 = self._deconv_block(4*root, 2*root, config)
        self.conv_6 = self._conv_block(4*root, 2*root, config)

        self.deconv_3 = self._deconv_block(2*root, root, config)

        # Output heads. Boundary sees the distance output, extent sees both distance and boundary outputs,
        # hence the growing number of input channels.
        self.distance = nn.Sequential(self._conv_block(2*root, root, config),
                                      self._head_block(root, n_classes))

        self.boundary = nn.Sequential(self._conv_block(2*root + n_classes, root, config, n_convs=1),
                                      self._head_block(root, n_classes))

        self.extent = nn.Sequential(self._conv_block(2*root + 2*n_classes, root, config),
                                    self._head_block(root, n_classes))

    @staticmethod
    def _conv_block(in_ch, out_ch, config, n_convs=2):
        """ Builds `n_convs` consecutive (Conv2d -> ReLU -> Dropout) groups; only the first conv changes channels """
        layers = []
        for idx in range(n_convs):
            layers += [nn.Conv2d(in_ch if idx == 0 else out_ch,
                                 out_ch,
                                 config.conv_size,
                                 stride=config.conv_stride,
                                 padding=config.conv_size//2),
                       nn.ReLU(),
                       nn.Dropout(p=1-config.keep_prob)]
        return nn.Sequential(*layers)

    @staticmethod
    def _pool_block(channels, config):
        """ Builds the strided, unpadded convolution (-> ReLU -> Dropout) that acts as pooling """
        return nn.Sequential(nn.Conv2d(channels,
                                       channels,
                                       config.pool_size,
                                       stride=config.pool_stride,
                                       padding=0),
                             nn.ReLU(), nn.Dropout(p=1-config.keep_prob))

    @staticmethod
    def _deconv_block(in_ch, out_ch, config):
        """ Builds the transposed convolution (-> ReLU) used for upsampling; stride equals the kernel size """
        return nn.Sequential(nn.ConvTranspose2d(in_ch,
                                                out_ch,
                                                config.deconv_size,
                                                stride=config.deconv_size),
                             nn.ReLU())

    @staticmethod
    def _head_block(in_ch, n_classes):
        """ Builds the 1x1 convolution followed by a softmax over the channel dimension """
        return nn.Sequential(nn.Conv2d(in_ch, n_classes, 1),
                             nn.Softmax(dim=1))

    def forward(self, x):
        """
        NIVA model v2 on input features
        Args:
            x : tensor (B, C, W, H)
        Returns:
            extent: tensor (B, C_out, W, H), extent pseudo-probas
            boundary: tensor (B, C_out, W, H), boundary pseudo-probas
            distance: tensor (B, C_out, W, H), distance pseudo-probas
        """
        # The last two dimensions are padded by 2 (at their end) before each unpadded pooling convolution.
        # NOTE(review): the fixed amount presumably matches the configured pool_size / pool_stride - confirm.
        x_1 = self.conv_1(x)
        x_p1 = self.conv_pool_1(nn.functional.pad(x_1, (0, 2, 0, 2)))

        x_2 = self.conv_2(x_p1)
        x_p2 = self.conv_pool_2(nn.functional.pad(x_2, (0, 2, 0, 2)))

        x_3 = self.conv_3(x_p2)
        x_p3 = self.conv_pool_3(nn.functional.pad(x_3, (0, 2, 0, 2)))

        x_4 = self.conv_4(x_p3)
        x_d4 = self.deconv_1(x_4)
        x_c4 = torch.cat([x_3, x_d4], 1)

        x_5 = self.conv_5(x_c4)
        x_d5 = self.deconv_2(x_5)
        x_c5 = torch.cat([x_2, x_d5], 1)

        x_6 = self.conv_6(x_c5)
        x_d6 = self.deconv_3(x_6)
        x_c6 = torch.cat([x_1, x_d6], 1)

        distance = self.distance(x_c6)

        cat_bound = torch.cat([x_c6, distance], 1)
        boundary = self.boundary(cat_bound)

        cat_extent = torch.cat([cat_bound, boundary], 1)
        extent = self.extent(cat_extent)

        return extent, boundary, distance
def get_closest_timestamp(timestamps: List[datetime], ref_timestamp: datetime) -> datetime:
    """ Get the timestamp closest to the reference timestamp

    The full time difference is compared. Comparing `timedelta.days` is not enough: for negative differences
    the days are rounded towards minus infinity, so a timestamp one hour before the reference would count as
    one day away while a timestamp 23 hours after it would count as zero days away.

    :param timestamps: candidate timestamps, must not be empty
    :param ref_timestamp: timestamp to which the distance is measured
    :returns: the candidate with the smallest absolute time difference; on ties the first one is returned
    :raises IndexError: if `timestamps` is empty
    """
    if len(timestamps) == 0:
        raise IndexError('Cannot select the closest timestamp from an empty sequence')

    return min(timestamps, key=lambda ts: abs(ts - ref_timestamp))
def get_test_path(path, request):
    """ Constructs path for the test execution from the test file's name, which it gets from
    pytest.FixtureRequest (https://docs.pytest.org/en/latest/reference.html#request).

    :param path: name of the sub-folder (or file) inside the data folder of the requesting test module
    :param request: pytest request object; its `fspath.basename` is expected to look like `test_<name>.py`
    :returns: the path `<ROOT>/data/<name>/<path>`
    :raises IndexError: if the requesting file is not named `test_<name>.py`
    """
    basename = request.fspath.basename

    # Raw string with an escaped dot, so that only a literal ".py" is accepted after the test name
    match = re.search(r"test_(.*)\.py", basename)
    if match is None:
        raise IndexError(f'Cannot derive the test name from {basename!r}, expected a file named test_<name>.py')

    return os.path.join(ROOT, 'data', match.group(1), path)
42 | """ 43 | 44 | out_path = request.config.getoption("out_folder") 45 | 46 | if out_path is None: 47 | out_path = get_test_path('output', request) 48 | if os.path.exists(out_path): 49 | shutil.rmtree(out_path) 50 | 51 | os.makedirs(out_path) 52 | 53 | yield out_path 54 | 55 | # shutil.rmtree(OUTPUT_FOLDER) 56 | 57 | -------------------------------------------------------------------------------- /tests/data/data_loader/input/data_3m_eopatch-0277_66_2x.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sentinel-hub/multi-temporal-super-resolution/5ef642304a980db87bdb935a7a7450bd649f8912/tests/data/data_loader/input/data_3m_eopatch-0277_66_2x.npz -------------------------------------------------------------------------------- /tests/data/data_loader/input/data_eopatch-0277_66_2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sentinel-hub/multi-temporal-super-resolution/5ef642304a980db87bdb935a7a7450bd649f8912/tests/data/data_loader/input/data_eopatch-0277_66_2.npz -------------------------------------------------------------------------------- /tests/data/data_loader/input/data_eopatch-0288_0_3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sentinel-hub/multi-temporal-super-resolution/5ef642304a980db87bdb935a7a7450bd649f8912/tests/data/data_loader/input/data_eopatch-0288_0_3.npz -------------------------------------------------------------------------------- /tests/data/data_loader/input/deimos_min_max_norm.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sentinel-hub/multi-temporal-super-resolution/5ef642304a980db87bdb935a7a7450bd649f8912/tests/data/data_loader/input/deimos_min_max_norm.npz -------------------------------------------------------------------------------- 
def test_data_loader(input_folder):
    """ Checks ImagesetDataset samples for three configurations: time-first with zero padding,
    channel-first with repeat padding, and un-normalized data compared against the raw npz content.
    """
    def in_input(filename):
        return os.path.join(input_folder, filename)

    df = pd.read_parquet(in_input('npz_info.pq'))

    min_max_de = np.load(in_input('deimos_min_max_norm.npz'))
    min_max_s2 = np.load(in_input('s2_min_max_norm.npz'))

    norm_per_country = pd.read_parquet(in_input('s2_norm_per_country.pq'))

    imset_npz_files = df.singleton_npz_filename.values

    aug_fn = partial(augment, permute_timestamps=False)

    def build_dataset(**variant):
        # Arguments shared by all three datasets; only the keyword arguments in `variant` differ
        return ImagesetDataset(imset_dir=input_folder,
                               filesystem=None,
                               imset_npz_files=imset_npz_files,
                               country_norm_df=norm_per_country,
                               norm_deimos_npz=min_max_de,
                               norm_s2_npz=min_max_s2,
                               channels_feats=[0, 1, 2, 3],
                               channels_labels=[0, 1, 2, 3],
                               **variant)

    expected_keys = {'name', 'lr', 'hr', 'alphas'}

    # Time-first layout, missing views padded with zeros
    hrn_dataset = build_dataset(time_first=True, normalize=True, n_views=8, padding='zeros', transform=aug_fn)

    sample = hrn_dataset[0]
    assert len(hrn_dataset) == 2
    assert sample.keys() == expected_keys
    assert sample['lr'].numpy().shape == (8, 4, 32, 32)
    assert sample['hr'].numpy().shape == (4, 128, 128)
    assert np.sum(sample['lr'].numpy()[-1]) == 0.
    assert sum(sample['alphas'].numpy()) == df.iloc[0].num_tstamps

    # Channel-first layout, missing views filled by repetition
    fra_dataset = build_dataset(time_first=False, normalize=True, n_views=9, padding='repeat', transform=aug_fn)

    sample = fra_dataset[1]
    assert len(fra_dataset) == 2
    assert sample.keys() == expected_keys
    assert sample['lr'].numpy().shape == (4, 9, 32, 32)
    assert sample['hr'].numpy().shape == (4, 128, 128)
    assert np.sum(sample['lr'].numpy()[-1]) != 0.
    assert sum(sample['alphas'].numpy()) == 9

    # No normalization and no augmentation: samples must match the raw npz content
    com_dataset = build_dataset(time_first=True, normalize=False, n_views=4, padding='zeros', transform=None)

    sample_npz = np.load(in_input(imset_npz_files[0]), allow_pickle=True)
    sample = com_dataset[0]

    np.testing.assert_equal(sample_npz['features'][-1][..., 0], sample['lr'][-1][0].numpy())
    np.testing.assert_equal(sample_npz['labels'][..., 0], sample['hr'][0].numpy())
def test_calculate_metrics():
    """ Checks calculate_metrics against the individual METRICS functions, with brightness correction,
    and for the type of the returned value (dict for a list of metrics, tensor for a single name).
    """
    x = torch.rand(5, 3, 64, 64)
    y = torch.rand(5, 3, 64, 64)

    factor = .3

    # A list of metric names must give a dictionary with exactly the scores of the individual functions
    scores = calculate_metrics(x, y, list(METRICS.keys()))

    assert set(scores.keys()) == set(METRICS.keys())

    for metric in METRICS:
        np.testing.assert_equal(scores[metric].numpy(), METRICS[metric](x, y).numpy())

    # A constant offset is removed by the brightness correction, so every metric reaches its ideal value
    scores = calculate_metrics(x, x+factor, list(METRICS.keys()), apply_correction=True)

    results = dict(MAE=0.0, MSE=0.0, PSNR=80.0, SSIM=1.0, MIXED=0.0)

    for metric in METRICS:
        np.testing.assert_almost_equal(np.mean(scores[metric].numpy()), results[metric], decimal=4)

    # isinstance is the idiomatic type check and also accepts subclasses
    assert isinstance(scores, dict)
    assert isinstance(calculate_metrics(x, x+factor, metrics='MAE'), torch.Tensor)
    assert isinstance(calculate_metrics(x, x+factor, metrics='MIXED'), torch.Tensor)