├── LICENSE.md ├── LR_models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── siamese_model_rgb.cpython-36.pyc ├── siamese_model_binary.py ├── siamese_model_binary_layerwise.py └── siamese_model_rgb.py ├── README.md ├── __init__.py ├── bass_model_fit_adoption_curves.py ├── bass_model_parameter_phase_analysis.ipynb ├── bass_model_parameter_regression.ipynb ├── generate_anchor_image_dict.ipynb ├── hp_search_HR.py ├── hp_search_LR_rgb.py ├── hp_search_ood_multilabels.py ├── predict_HR.py ├── predict_LR_rgb.py ├── predict_installation_year_from_image_sequences.ipynb ├── predict_ood_multilabels.py ├── requirements.txt └── utils ├── .DS_Store ├── .ipynb_checkpoints ├── binarize-checkpoint.ipynb └── collect_image_info-checkpoint.ipynb ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc └── image_dataset.cpython-36.pyc ├── image_dataset.py └── inception_modified.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /LR_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzhecheng/DeepSolar_timelapse/b9f0e0e81b901426a166f55a3ba6ab53ea260407/LR_models/__init__.py -------------------------------------------------------------------------------- /LR_models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzhecheng/DeepSolar_timelapse/b9f0e0e81b901426a166f55a3ba6ab53ea260407/LR_models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /LR_models/__pycache__/siamese_model_rgb.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzhecheng/DeepSolar_timelapse/b9f0e0e81b901426a166f55a3ba6ab53ea260407/LR_models/__pycache__/siamese_model_rgb.cpython-36.pyc -------------------------------------------------------------------------------- /LR_models/siamese_model_binary.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8 -*- 2 | 3 | from __future__ import print_function 4 | from __future__ import division 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.utils.data import Dataset, DataLoader 9 | import torchvision 10 | from torchvision import datasets, models, transforms, utils 11 | import torchvision.transforms.functional as TF 12 | 13 | from tqdm import tqdm 14 | import numpy as np 15 | import pandas as pd 16 | import pickle 17 | import matplotlib.pyplot as plt 18 | # import skimage 19 | # import skimage.io 20 | # import skimage.transform 21 | from PIL import Image 22 | import time 23 | import os 24 | from os.path import join, exists 25 | import copy 26 | import random 27 | from collections import OrderedDict 28 | from sklearn.metrics import r2_score 29 | 30 | 31 | import torch.nn.functional as F 32 | from torchvision.models import Inception3, resnet18, ResNet 33 | from collections import namedtuple 34 | 35 | 36 | def conv3x3(in_planes, out_planes, stride=1): 37 | """3x3 convolution with padding""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 39 | padding=1, bias=False) 40 | 41 | 42 | class BasicBlock(nn.Module): 43 | expansion = 1 44 | 45 | def __init__(self, inplanes, planes, stride=1, downsample=None): 46 | super(BasicBlock, self).__init__() 47 | self.conv1 = conv3x3(inplanes, planes, stride) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.relu = nn.ReLU(inplace=True) 50 | self.conv2 = conv3x3(planes, planes) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.downsample = downsample 53 | self.stride = stride 54 | 55 | def forward(self, x): 56 | residual = x 57 | 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | 62 | out = self.conv2(out) 63 | out = self.bn2(out) 64 | 65 | if self.downsample is not None: 66 | residual = self.downsample(x) 67 | 68 | out += residual 69 | out = self.relu(out) 70 | 71 | return out 72 | 73 | 74 | class resnet18_modified(ResNet): 75 | def forward(self, x): 76 | x = self.conv1(x) 77 | x = self.bn1(x) 78 | x = self.relu(x) 79 | x = self.maxpool(x) 80 | 81 | x = self.layer1(x) 82 | # 64 x 75 x 75 83 | fm38 = self.layer2(x) 84 | # 128 x 38 x 38 85 | fm19 = self.layer3(fm38) 86 | # 256 x 19 x 19 87 | # fm10 = 
self.layer4(fm19) 88 | # 512 x 10 x 10 89 | x256 = self.avgpool(fm19) 90 | x = x256.view(x256.size(0), -1) 91 | 92 | # x512 = self.avgpool(fm10) 93 | # x = x512.view(x512.size(0), -1) 94 | # x = self.fc(x) 95 | 96 | return x, fm19 97 | # return x256, fm19 98 | 99 | 100 | class sn_depthwise_cc(nn.Module): 101 | def __init__(self, nconvs=2, nfilters=256): 102 | super(sn_depthwise_cc, self).__init__() 103 | self.net1 = resnet18_modified(BasicBlock, [2, 2, 2, 2]) 104 | self.net1.fc = nn.Linear(256, 128) 105 | 106 | self.nconvs = nconvs 107 | assert self.nconvs in [0, 1, 2, 3] 108 | 109 | if self.nconvs in [1, 2, 3]: 110 | self.conv_combo_1 = nn.Sequential(nn.Conv2d( 111 | 256, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 112 | 113 | if self.nconvs in [2, 3]: 114 | self.conv_combo_2 = nn.Sequential(nn.Conv2d( 115 | nfilters, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), 116 | nn.ReLU(inplace=True)) 117 | 118 | if self.nconvs == 3: 119 | self.conv_combo_3 = nn.Sequential(nn.Conv2d( 120 | nfilters, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), 121 | nn.ReLU(inplace=True)) 122 | 123 | if self.nconvs == 0: 124 | self.fc = nn.Linear(256, 2) 125 | else: 126 | self.fc = nn.Linear(nfilters, 2) 127 | 128 | def forward(self, img1, img2): 129 | img1 = torch.cat([img1, img1, img1], 1) 130 | img2 = torch.cat([img2, img2, img2], 1) 131 | 132 | x1, fm1 = self.net1(img1) 133 | x2, fm2 = self.net1(img2) 134 | 135 | # depth-wise cross correlation 136 | nchannels = fm1.size()[1] 137 | fm1 = fm1.reshape(-1, fm1.size()[2], fm1.size()[3]) 138 | fm1 = fm1.unsqueeze(0) 139 | fm2 = fm2.reshape(-1, fm2.size()[2], fm2.size()[3]) 140 | fm2 = fm2.unsqueeze(0).permute(1, 0, 2, 3) 141 | new_vec = F.conv2d(fm1, fm2, padding=9, stride=1, groups=fm2.size()[0]).squeeze() 142 | new_vec = new_vec.reshape(-1, nchannels, new_vec.size()[1], new_vec.size()[2]) 143 | out = F.relu(new_vec) 144 | 145 | # convolution layers 146 | if self.nconvs in [1, 2, 3]: 147 | out = self.conv_combo_1(out) 148 | if self.nconvs in [2, 3]: 149 | out = self.conv_combo_2(out) 150 | if self.nconvs == 3: 151 | out = self.conv_combo_3(out) 152 | 153 | # global average pooling 154 | out = F.adaptive_avg_pool2d(out, (1, 1)) 155 | out = out.view(out.size(0), -1) 156 | out = self.fc(out) 157 | return out 158 | 159 | 160 | class sn_depthwise_cc_1x1(nn.Module): 161 | def __init__(self, nconvs=2, nfilters=256): 162 | super(sn_depthwise_cc_1x1, self).__init__() 163 | self.net1 = resnet18_modified(BasicBlock, [2, 2, 2, 2]) 164 | self.net1.fc = nn.Linear(256, 128) 165 | 166 | self.nconvs = nconvs 167 | assert self.nconvs in [0, 1, 2, 3] 168 | 169 | if self.nconvs in [1, 2, 3]: 170 | self.conv_combo_1 = nn.Sequential(nn.Conv2d( 171 | 256, nfilters, kernel_size=1, padding=0, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 172 | 173 | if self.nconvs in [2, 3]: 174 | self.conv_combo_2 = nn.Sequential(nn.Conv2d( 175 | nfilters, nfilters, kernel_size=1, padding=0, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 176 | 177 | if self.nconvs == 3: 178 | self.conv_combo_3 = nn.Sequential(nn.Conv2d( 179 | nfilters, nfilters, kernel_size=1, padding=0, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 180 | 181 | if self.nconvs == 0: 182 | self.fc = nn.Linear(256, 2) 183 | else: 184 | self.fc = nn.Linear(nfilters, 2) 185 | 186 | def forward(self, img1, img2): 187 | img1 = torch.cat([img1, img1, img1], 1) 188 | img2 = torch.cat([img2, img2, img2], 
1) 189 | 190 | x1, fm1 = self.net1(img1) 191 | x2, fm2 = self.net1(img2) 192 | 193 | # depth-wise cross correlation 194 | nchannels = fm1.size()[1] 195 | fm1 = fm1.reshape(-1, fm1.size()[2], fm1.size()[3]) 196 | fm1 = fm1.unsqueeze(0) 197 | fm2 = fm2.reshape(-1, fm2.size()[2], fm2.size()[3]) 198 | fm2 = fm2.unsqueeze(0).permute(1, 0, 2, 3) 199 | new_vec = F.conv2d(fm1, fm2, padding=9, stride=1, groups=fm2.size()[0]).squeeze() 200 | new_vec = new_vec.reshape(-1, nchannels, new_vec.size()[1], new_vec.size()[2]) 201 | out = F.relu(new_vec) 202 | 203 | # convolution layers 204 | if self.nconvs in [1, 2, 3]: 205 | out = self.conv_combo_1(out) 206 | if self.nconvs in [2, 3]: 207 | out = self.conv_combo_2(out) 208 | if self.nconvs == 3: 209 | out = self.conv_combo_3(out) 210 | 211 | # global average pooling 212 | out = F.adaptive_avg_pool2d(out, (1, 1)) 213 | out = out.view(out.size(0), -1) 214 | out = self.fc(out) 215 | return out 216 | 217 | -------------------------------------------------------------------------------- /LR_models/siamese_model_binary_layerwise.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8 -*- 2 | 3 | from __future__ import print_function 4 | from __future__ import division 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.utils.data import Dataset, DataLoader 9 | import torchvision 10 | from torchvision import datasets, models, transforms, utils 11 | import torchvision.transforms.functional as TF 12 | 13 | from tqdm import tqdm 14 | import numpy as np 15 | import pandas as pd 16 | import pickle 17 | import matplotlib.pyplot as plt 18 | # import skimage 19 | # import skimage.io 20 | # import skimage.transform 21 | from PIL import Image 22 | import time 23 | import os 24 | from os.path import join, exists 25 | import copy 26 | import random 27 | from collections import OrderedDict 28 | from sklearn.metrics import r2_score 29 | 30 | 31 | import torch.nn.functional as F 32 | from torchvision.models import Inception3, resnet18, ResNet 33 | from collections import namedtuple 34 | 35 | 36 | def conv3x3(in_planes, out_planes, stride=1): 37 | """3x3 convolution with padding""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 39 | padding=1, bias=False) 40 | 41 | 42 | class BasicBlock(nn.Module): 43 | expansion = 1 44 | 45 | def __init__(self, inplanes, planes, stride=1, downsample=None): 46 | super(BasicBlock, self).__init__() 47 | self.conv1 = conv3x3(inplanes, planes, stride) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.relu = nn.ReLU(inplace=True) 50 | self.conv2 = conv3x3(planes, planes) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.downsample = downsample 53 | self.stride = stride 54 | 55 | def forward(self, x): 56 | residual = x 57 | 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | 62 | out = self.conv2(out) 63 | out = self.bn2(out) 64 | 65 | if self.downsample is not None: 66 | residual = self.downsample(x) 67 | 68 | out += residual 69 | out = self.relu(out) 70 | 71 | return out 72 | 73 | 74 | class resnet18_modified(ResNet): 75 | def forward(self, x): 76 | x = self.conv1(x) 77 | x = self.bn1(x) 78 | x = self.relu(x) 79 | x = self.maxpool(x) 80 | 81 | x = self.layer1(x) 82 | # 64 x 75 x 75 83 | fm38 = self.layer2(x) 84 | # 128 x 38 x 38 85 | fm19 = self.layer3(fm38) 86 | # 256 x 19 x 19 87 | # fm10 = self.layer4(fm19) 88 | # 512 x 10 x 10 89 | x256 = self.avgpool(fm19) 90 | x = x256.view(x256.size(0), -1) 91 | 92 | 
# x512 = self.avgpool(fm10) 93 | # x = x512.view(x512.size(0), -1) 94 | # x = self.fc(x) 95 | 96 | return x, fm38, fm19 97 | # return x256, fm19 98 | 99 | 100 | class resnet18_modified_2(ResNet): 101 | def forward(self, x): 102 | x = self.conv1(x) 103 | x = self.bn1(x) 104 | x = self.relu(x) 105 | x = self.maxpool(x) 106 | 107 | fm75 = self.layer1(x) 108 | # 64 x 75 x 75 109 | fm38 = self.layer2(fm75) 110 | # 128 x 38 x 38 111 | fm19 = self.layer3(fm38) 112 | # 256 x 19 x 19 113 | # fm10 = self.layer4(fm19) 114 | # 512 x 10 x 10 115 | x256 = self.avgpool(fm19) 116 | x = x256.view(x256.size(0), -1) 117 | 118 | # x512 = self.avgpool(fm10) 119 | # x = x512.view(x512.size(0), -1) 120 | # x = self.fc(x) 121 | 122 | return x, fm75, fm38, fm19 123 | # return x256, fm19 124 | 125 | 126 | class resnet18_modified_3(ResNet): 127 | def forward(self, x): 128 | x = self.conv1(x) 129 | x = self.bn1(x) 130 | x = self.relu(x) 131 | x = self.maxpool(x) 132 | 133 | fm75 = self.layer1(x) 134 | # 64 x 75 x 75 135 | fm38 = self.layer2(fm75) 136 | # 128 x 38 x 38 137 | fm19 = self.layer3(fm38) 138 | # 256 x 19 x 19 139 | fm10 = self.layer4(fm19) 140 | # 512 x 10 x 10 141 | # x256 = self.avgpool(fm19) 142 | # x = x256.view(x256.size(0), -1) 143 | 144 | # x512 = self.avgpool(fm10) 145 | # x = x512.view(x512.size(0), -1) 146 | # x = self.fc(x) 147 | 148 | return fm75, fm38, fm19, fm10 149 | 150 | 151 | def XCross_depthwise(fm1, fm2, padding): 152 | """ 153 | :param fm1: inputs 154 | :param fm2: filters 155 | :param padding: padding in cross correlation layer 156 | :return: similarity map after cross correlation 157 | """ 158 | nchannels = fm1.size()[1] 159 | fm1 = fm1.reshape(-1, fm1.size()[2], fm1.size()[3]) 160 | fm1 = fm1.unsqueeze(0) 161 | fm2 = fm2.reshape(-1, fm2.size()[2], fm2.size()[3]) 162 | fm2 = fm2.unsqueeze(0).permute(1, 0, 2, 3) 163 | out = F.conv2d(fm1, fm2, padding=padding, stride=1, groups=fm2.size()[0]).squeeze() 164 | out = out.reshape(-1, nchannels, out.size()[1], out.size()[2]) 165 | return out 166 | 167 | 168 | def XCross(fm1, fm2, padding): 169 | """ 170 | :param fm1: inputs 171 | :param fm2: filters 172 | :param padding: padding in cross correlation layer 173 | :return: similarity map after cross correlation 174 | """ 175 | nchannels = fm1.size()[1] 176 | fm1 = fm1.reshape(-1, fm1.size()[2], fm1.size()[3]) 177 | fm1 = fm1.unsqueeze(0) 178 | fm2 = fm2.reshape(-1, fm2.size()[2], fm2.size()[3]) 179 | fm2 = fm2.unsqueeze(0).permute(1, 0, 2, 3) 180 | out = F.conv2d(fm1, fm2, padding=padding, stride=1, groups=fm2.size()[0]).squeeze() 181 | out = out.reshape(-1, nchannels, out.size()[1], out.size()[2]) 182 | return torch.sum(out, keepdim=True, dim=1) # batch_size x 1 x H x W 183 | 184 | 185 | class sn_cc_layerwise(nn.Module): 186 | def __init__(self, nlayers=2, hidden=256): 187 | super(sn_cc_layerwise, self).__init__() 188 | self.net1 = resnet18_modified(BasicBlock, [2, 2, 2, 2]) 189 | self.net1.fc = nn.Linear(256, 128) 190 | 191 | self.extra_1x1_conv = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=False) 192 | 193 | self.nlayers = nlayers 194 | assert self.nlayers in [1, 2, 3] 195 | 196 | if self.nlayers in [2, 3]: 197 | self.fc_combo_1 = nn.Sequential(nn.Linear(722, hidden), nn.ReLU(inplace=True)) 198 | 199 | if self.nlayers == 3: 200 | self.fc_combo_2 = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True)) 201 | 202 | if self.nlayers == 1: 203 | self.fc_final = nn.Linear(722, 2) 204 | else: 205 | self.fc_final = nn.Linear(hidden, 2) 206 | 207 | def forward(self, img1, 
img2): 208 | img1 = torch.cat([img1, img1, img1], 1) 209 | img2 = torch.cat([img2, img2, img2], 1) 210 | 211 | x1, fm38_1, fm19_1 = self.net1(img1) 212 | x2, fm38_2, fm19_2 = self.net1(img2) 213 | 214 | # cross correlation 215 | sm38 = XCross(fm38_1, fm38_2, padding=19) # batch_size * 1 * 39 * 39 216 | sm19 = XCross(fm19_1, fm19_2, padding=9) # batch_size * 1 * 19 * 19 217 | 218 | # aggregate 219 | sm38 = F.adaptive_avg_pool2d(sm38, (sm19.size()[2], sm19.size()[3])) # batch_size * 128 * 19 * 19 220 | sm38 = sm38.view(sm38.size()[0], -1) # batch_size * 361 221 | sm19 = sm19.view(sm38.size()[0], -1) # batch_size * 361 222 | out = torch.cat([sm38, sm19], dim=1) # batch_size * 722 223 | out = F.relu(out) 224 | 225 | # fc layers 226 | if self.nlayers in [2, 3]: 227 | out = self.fc_combo_1(out) 228 | if self.nlayers == 3: 229 | out = self.fc_combo_2(out) 230 | out = self.fc_final(out) 231 | return out 232 | 233 | 234 | class sn_depthwise_cc_layerwise(nn.Module): 235 | def __init__(self, nconvs=2, nfilters=256): 236 | super(sn_depthwise_cc_layerwise, self).__init__() 237 | self.net1 = resnet18_modified(BasicBlock, [2, 2, 2, 2]) 238 | self.net1.fc = nn.Linear(256, 128) 239 | 240 | self.extra_1x1_conv = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=False) 241 | 242 | self.nconvs = nconvs 243 | assert self.nconvs in [0, 1, 2, 3] 244 | 245 | if self.nconvs in [1, 2, 3]: 246 | self.conv_combo_1 = nn.Sequential(nn.Conv2d( 247 | 256, nfilters, kernel_size=1, padding=0, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 248 | 249 | if self.nconvs in [2, 3]: 250 | self.conv_combo_2 = nn.Sequential(nn.Conv2d( 251 | nfilters, nfilters, kernel_size=1, padding=0, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 252 | 253 | if self.nconvs == 3: 254 | self.conv_combo_3 = nn.Sequential(nn.Conv2d( 255 | nfilters, nfilters, kernel_size=1, padding=0, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 256 | 257 | if self.nconvs == 0: 258 | self.fc = nn.Linear(256, 2) 259 | else: 260 | self.fc = nn.Linear(nfilters, 2) 261 | 262 | def forward(self, img1, img2): 263 | img1 = torch.cat([img1, img1, img1], 1) 264 | img2 = torch.cat([img2, img2, img2], 1) 265 | 266 | x1, fm38_1, fm19_1 = self.net1(img1) 267 | x2, fm38_2, fm19_2 = self.net1(img2) 268 | 269 | # change the number of channels 270 | fm19_1 = self.extra_1x1_conv(fm19_1) # batch_size * 128 * 19 * 19 271 | fm19_2 = self.extra_1x1_conv(fm19_2) # batch_size * 128 * 19 * 19 272 | 273 | # depth-wise cross correlation 274 | sm38 = XCross_depthwise(fm38_1, fm38_2, padding=19) # batch_size * 128 * 38 * 38 275 | sm19 = XCross_depthwise(fm19_1, fm19_2, padding=9) # batch_size * 128 * 19 * 19 276 | 277 | # aggregate along the channel 278 | sm38_downsampled = F.adaptive_avg_pool2d(sm38, (sm19.size()[2], sm19.size()[3])) # batch_size * 128 * 19 * 19 279 | out = torch.cat([sm38_downsampled, sm19], 1) # batch_size * 256 * 19 * 19 280 | out = F.relu(out) 281 | 282 | # convolution layers 283 | if self.nconvs in [1, 2, 3]: 284 | out = self.conv_combo_1(out) 285 | if self.nconvs in [2, 3]: 286 | out = self.conv_combo_2(out) 287 | if self.nconvs == 3: 288 | out = self.conv_combo_3(out) 289 | 290 | # global average pooling 291 | out = F.adaptive_avg_pool2d(out, (1, 1)) 292 | out = out.view(out.size(0), -1) 293 | out = self.fc(out) 294 | return out 295 | 296 | 297 | class sn_depthwise_cc_layerwise_3x3(nn.Module): 298 | def __init__(self, nconvs=2, nfilters=256): 299 | super(sn_depthwise_cc_layerwise_3x3, self).__init__() 300 | self.net1 = 
resnet18_modified(BasicBlock, [2, 2, 2, 2]) 301 | self.net1.fc = nn.Linear(256, 128) 302 | 303 | self.extra_1x1_conv = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=False) 304 | 305 | self.nconvs = nconvs 306 | assert self.nconvs in [0, 1, 2, 3] 307 | 308 | if self.nconvs in [1, 2, 3]: 309 | self.conv_combo_1 = nn.Sequential(nn.Conv2d( 310 | 256, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 311 | 312 | if self.nconvs in [2, 3]: 313 | self.conv_combo_2 = nn.Sequential(nn.Conv2d( 314 | nfilters, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 315 | 316 | if self.nconvs == 3: 317 | self.conv_combo_3 = nn.Sequential(nn.Conv2d( 318 | nfilters, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 319 | 320 | if self.nconvs == 0: 321 | self.fc = nn.Linear(256, 2) 322 | else: 323 | self.fc = nn.Linear(nfilters, 2) 324 | 325 | def forward(self, img1, img2): 326 | img1 = torch.cat([img1, img1, img1], 1) 327 | img2 = torch.cat([img2, img2, img2], 1) 328 | 329 | x1, fm38_1, fm19_1 = self.net1(img1) 330 | x2, fm38_2, fm19_2 = self.net1(img2) 331 | 332 | # change the number of channels 333 | fm19_1 = self.extra_1x1_conv(fm19_1) # batch_size * 128 * 19 * 19 334 | fm19_2 = self.extra_1x1_conv(fm19_2) # batch_size * 128 * 19 * 19 335 | 336 | # depth-wise cross correlation 337 | sm38 = XCross_depthwise(fm38_1, fm38_2, padding=19) # batch_size * 128 * 38 * 38 338 | sm19 = XCross_depthwise(fm19_1, fm19_2, padding=9) # batch_size * 128 * 19 * 19 339 | 340 | # aggregate along the channel 341 | sm38_downsampled = F.adaptive_avg_pool2d(sm38, (sm19.size()[2], sm19.size()[3])) # batch_size * 128 * 19 * 19 342 | out = torch.cat([sm38_downsampled, sm19], 1) # batch_size * 256 * 19 * 19 343 | out = F.relu(out) 344 | 345 | # convolution layers 346 | if self.nconvs in [1, 2, 3]: 347 | out = self.conv_combo_1(out) 348 | if self.nconvs in [2, 3]: 349 | out = self.conv_combo_2(out) 350 | if self.nconvs == 3: 351 | out = self.conv_combo_3(out) 352 | 353 | # global average pooling 354 | out = F.adaptive_avg_pool2d(out, (1, 1)) 355 | out = out.view(out.size(0), -1) 356 | out = self.fc(out) 357 | return out 358 | 359 | 360 | 361 | class sn_depthwise_cc_layerwise_3x3_3layers(nn.Module): 362 | """ 363 | Depthwise cross-correlation, layerwise aggregation (aggregate similarity maps from 3 layers), with 3x3 convolution kernel. 
364 | """ 365 | def __init__(self, nconvs=2, nfilters=256): 366 | super(sn_depthwise_cc_layerwise_3x3_3layers, self).__init__() 367 | self.net1 = resnet18_modified_2(BasicBlock, [2, 2, 2, 2]) 368 | self.net1.fc = nn.Linear(256, 128) 369 | 370 | self.extra_1x1_conv_1 = nn.Conv2d(64, 128, kernel_size=1, stride=1, padding=0, bias=False) 371 | self.extra_1x1_conv_3 = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=False) 372 | 373 | self.nconvs = nconvs 374 | assert self.nconvs in [0, 1, 2, 3] 375 | 376 | if self.nconvs in [1, 2, 3]: 377 | self.conv_combo_1 = nn.Sequential(nn.Conv2d( 378 | 384, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 379 | 380 | if self.nconvs in [2, 3]: 381 | self.conv_combo_2 = nn.Sequential(nn.Conv2d( 382 | nfilters, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 383 | 384 | if self.nconvs == 3: 385 | self.conv_combo_3 = nn.Sequential(nn.Conv2d( 386 | nfilters, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 387 | 388 | if self.nconvs == 0: 389 | self.fc = nn.Linear(384, 2) 390 | else: 391 | self.fc = nn.Linear(nfilters, 2) 392 | 393 | def forward(self, img1, img2): 394 | img1 = torch.cat([img1, img1, img1], 1) 395 | img2 = torch.cat([img2, img2, img2], 1) 396 | 397 | x1, fm75_1, fm38_1, fm19_1 = self.net1(img1) 398 | x2, fm75_2, fm38_2, fm19_2 = self.net1(img2) 399 | 400 | # change the number of channels 401 | fm75_1 = self.extra_1x1_conv_1(fm75_1) # batch_size * 128 * 75 * 75 402 | fm75_2 = self.extra_1x1_conv_1(fm75_2) # batch_size * 128 * 75 * 75 403 | 404 | fm19_1 = self.extra_1x1_conv_3(fm19_1) # batch_size * 128 * 19 * 19 405 | fm19_2 = self.extra_1x1_conv_3(fm19_2) # batch_size * 128 * 19 * 19 406 | 407 | # depth-wise cross correlation 408 | sm75 = XCross_depthwise(fm75_1, fm75_2, padding=37) # batch_size * 128 * 75 * 75 409 | sm38 = XCross_depthwise(fm38_1, fm38_2, padding=19) # batch_size * 128 * 39 * 39 410 | sm19 = XCross_depthwise(fm19_1, fm19_2, padding=9) # batch_size * 128 * 19 * 19 411 | 412 | # aggregate along the channel 413 | sm75_downsampled = F.adaptive_avg_pool2d(sm75, (sm19.size()[2], sm19.size()[3])) # batch_size * 128 * 19 * 19 414 | sm38_downsampled = F.adaptive_avg_pool2d(sm38, (sm19.size()[2], sm19.size()[3])) # batch_size * 128 * 19 * 19 415 | out = torch.cat([sm75_downsampled, sm38_downsampled, sm19], 1) # batch_size * 384 * 19 * 19 416 | out = F.relu(out) 417 | 418 | # convolution layers 419 | if self.nconvs in [1, 2, 3]: 420 | out = self.conv_combo_1(out) 421 | if self.nconvs in [2, 3]: 422 | out = self.conv_combo_2(out) 423 | if self.nconvs == 3: 424 | out = self.conv_combo_3(out) 425 | 426 | # global average pooling 427 | out = F.adaptive_avg_pool2d(out, (1, 1)) 428 | out = out.view(out.size(0), -1) 429 | out = self.fc(out) 430 | return out 431 | 432 | 433 | class sn_depthwise_cc_layerwise_last2(nn.Module): 434 | def __init__(self, nconvs=2, nfilters=256): 435 | super(sn_depthwise_cc_layerwise_last2, self).__init__() 436 | self.net1 = resnet18_modified_3(BasicBlock, [2, 2, 2, 2]) 437 | self.net1.fc = nn.Linear(256, 128) 438 | 439 | self.extra_1x1_conv = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0, bias=False) 440 | 441 | self.nconvs = nconvs 442 | assert self.nconvs in [0, 1, 2, 3] 443 | 444 | if self.nconvs in [1, 2, 3]: 445 | self.conv_combo_1 = nn.Sequential(nn.Conv2d( 446 | 512, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), 
nn.ReLU(inplace=True)) 447 | 448 | if self.nconvs in [2, 3]: 449 | self.conv_combo_2 = nn.Sequential(nn.Conv2d( 450 | nfilters, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 451 | 452 | if self.nconvs == 3: 453 | self.conv_combo_3 = nn.Sequential(nn.Conv2d( 454 | nfilters, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 455 | 456 | if self.nconvs == 0: 457 | self.fc = nn.Linear(512, 2) 458 | else: 459 | self.fc = nn.Linear(nfilters, 2) 460 | 461 | def forward(self, img1, img2): 462 | img1 = torch.cat([img1, img1, img1], 1) 463 | img2 = torch.cat([img2, img2, img2], 1) 464 | 465 | _, _, fm19_1, fm10_1 = self.net1(img1) 466 | _, _, fm19_2, fm10_2 = self.net1(img2) 467 | 468 | # change the number of channels 469 | fm10_1 = self.extra_1x1_conv(fm10_1) # batch_size * 256 * 10 * 10 470 | fm10_2 = self.extra_1x1_conv(fm10_2) # batch_size * 256 * 10 * 10 471 | 472 | # depth-wise cross correlation 473 | sm19 = XCross_depthwise(fm19_1, fm19_2, padding=9) # batch_size * 256 * 19 * 19 474 | sm10 = XCross_depthwise(fm10_1, fm10_2, padding=5) # batch_size * 256 * 11 * 11 475 | 476 | # aggregate along the channel 477 | sm19_downsampled = F.adaptive_avg_pool2d(sm19, (sm10.size()[2], sm10.size()[3])) # batch_size * 256 * 11 * 11 478 | out = torch.cat([sm19_downsampled, sm10], 1) # batch_size * 512 * 11 * 11 479 | out = F.relu(out) 480 | 481 | # convolution layers 482 | if self.nconvs in [1, 2, 3]: 483 | out = self.conv_combo_1(out) 484 | if self.nconvs in [2, 3]: 485 | out = self.conv_combo_2(out) 486 | if self.nconvs == 3: 487 | out = self.conv_combo_3(out) 488 | 489 | # global average pooling 490 | out = F.adaptive_avg_pool2d(out, (1, 1)) 491 | out = out.view(out.size(0), -1) 492 | out = self.fc(out) 493 | return out 494 | 495 | 496 | class sn_depthwise_cc_layerwise_last3(nn.Module): 497 | """ 498 | Depthwise cross-correlation, layerwise aggregation (aggregate similarity maps from 3 layers), with 3x3 convolution kernel. 
499 | """ 500 | def __init__(self, nconvs=2, nfilters=256): 501 | super(sn_depthwise_cc_layerwise_last3, self).__init__() 502 | self.net1 = resnet18_modified_3(BasicBlock, [2, 2, 2, 2]) 503 | self.net1.fc = nn.Linear(256, 128) 504 | 505 | self.extra_1x1_conv_2 = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=False) 506 | self.extra_1x1_conv_3 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0, bias=False) 507 | 508 | self.nconvs = nconvs 509 | assert self.nconvs in [0, 1, 2, 3] 510 | 511 | if self.nconvs in [1, 2, 3]: 512 | self.conv_combo_1 = nn.Sequential(nn.Conv2d( 513 | 384, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 514 | 515 | if self.nconvs in [2, 3]: 516 | self.conv_combo_2 = nn.Sequential(nn.Conv2d( 517 | nfilters, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 518 | 519 | if self.nconvs == 3: 520 | self.conv_combo_3 = nn.Sequential(nn.Conv2d( 521 | nfilters, nfilters, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 522 | 523 | if self.nconvs == 0: 524 | self.fc = nn.Linear(384, 2) 525 | else: 526 | self.fc = nn.Linear(nfilters, 2) 527 | 528 | def forward(self, img1, img2): 529 | img1 = torch.cat([img1, img1, img1], 1) 530 | img2 = torch.cat([img2, img2, img2], 1) 531 | 532 | _, fm38_1, fm19_1, fm10_1 = self.net1(img1) 533 | _, fm38_2, fm19_2, fm10_2 = self.net1(img2) 534 | 535 | # change the number of channels 536 | fm19_1 = self.extra_1x1_conv_2(fm19_1) # batch_size * 128 * 19 * 19 537 | fm19_2 = self.extra_1x1_conv_2(fm19_2) # batch_size * 128 * 19 * 19 538 | 539 | fm10_1 = self.extra_1x1_conv_3(fm10_1) # batch_size * 128 * 10 * 10 540 | fm10_2 = self.extra_1x1_conv_3(fm10_2) # batch_size * 128 * 10 * 10 541 | 542 | # depth-wise cross correlation 543 | sm38 = XCross_depthwise(fm38_1, fm38_2, padding=19) # batch_size * 128 * 39 * 39 544 | sm19 = XCross_depthwise(fm19_1, fm19_2, padding=9) # batch_size * 128 * 19 * 19 545 | sm10 = XCross_depthwise(fm10_1, fm10_2, padding=5) # batch_size * 128 * 11 * 11 546 | 547 | 548 | # aggregate along the channel 549 | sm38_downsampled = F.adaptive_avg_pool2d(sm38, (sm10.size()[2], sm10.size()[3])) # batch_size * 128 * 11 * 11 550 | sm19_downsampled = F.adaptive_avg_pool2d(sm19, (sm10.size()[2], sm10.size()[3])) # batch_size * 128 * 11 * 11 551 | out = torch.cat([sm38_downsampled, sm19_downsampled, sm10], 1) # batch_size * 384 * 11 * 11 552 | out = F.relu(out) 553 | 554 | # convolution layers 555 | if self.nconvs in [1, 2, 3]: 556 | out = self.conv_combo_1(out) 557 | if self.nconvs in [2, 3]: 558 | out = self.conv_combo_2(out) 559 | if self.nconvs == 3: 560 | out = self.conv_combo_3(out) 561 | 562 | # global average pooling 563 | out = F.adaptive_avg_pool2d(out, (1, 1)) 564 | out = out.view(out.size(0), -1) 565 | out = self.fc(out) 566 | return out -------------------------------------------------------------------------------- /LR_models/siamese_model_rgb.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8 -*- 2 | 3 | from __future__ import print_function 4 | from __future__ import division 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.utils.data import Dataset, DataLoader 9 | import torchvision 10 | from torchvision import datasets, models, transforms, utils 11 | import torchvision.transforms.functional as TF 12 | 13 | from tqdm import tqdm 14 | import numpy as np 15 | import pandas as 
pd 16 | import pickle 17 | import matplotlib.pyplot as plt 18 | # import skimage 19 | # import skimage.io 20 | # import skimage.transform 21 | from PIL import Image 22 | import time 23 | import os 24 | from os.path import join, exists 25 | import copy 26 | import random 27 | from collections import OrderedDict 28 | from sklearn.metrics import r2_score 29 | 30 | 31 | import torch.nn.functional as F 32 | from torchvision.models import Inception3, resnet18, ResNet 33 | from collections import namedtuple 34 | 35 | 36 | def conv3x3(in_planes, out_planes, stride=1): 37 | """3x3 convolution with padding""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 39 | padding=1, bias=False) 40 | 41 | 42 | def conv1x1(in_planes, out_planes, stride=1): 43 | """1x1 convolution""" 44 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 45 | 46 | 47 | class BasicBlock(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(BasicBlock, self).__init__() 52 | self.conv1 = conv3x3(inplanes, planes, stride) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(planes, planes) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x): 61 | residual = x 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.relu(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | 70 | if self.downsample is not None: 71 | residual = self.downsample(x) 72 | 73 | out += residual 74 | out = self.relu(out) 75 | 76 | return out 77 | 78 | 79 | class Bottleneck(nn.Module): 80 | expansion = 4 81 | 82 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 83 | base_width=64, dilation=1, norm_layer=None): 84 | super(Bottleneck, self).__init__() 85 | if norm_layer is None: 86 | norm_layer = nn.BatchNorm2d 87 | width = int(planes * (base_width / 64.)) * groups 88 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 89 | self.conv1 = conv1x1(inplanes, width) 90 | self.bn1 = norm_layer(width) 91 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 92 | self.bn2 = norm_layer(width) 93 | self.conv3 = conv1x1(width, planes * self.expansion) 94 | self.bn3 = norm_layer(planes * self.expansion) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.downsample = downsample 97 | self.stride = stride 98 | 99 | def forward(self, x): 100 | identity = x 101 | 102 | out = self.conv1(x) 103 | out = self.bn1(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv2(out) 107 | out = self.bn2(out) 108 | out = self.relu(out) 109 | 110 | out = self.conv3(out) 111 | out = self.bn3(out) 112 | 113 | if self.downsample is not None: 114 | identity = self.downsample(x) 115 | 116 | out += identity 117 | out = self.relu(out) 118 | 119 | return out 120 | 121 | 122 | class resnet_modified(ResNet): 123 | def forward(self, x): 124 | x = self.conv1(x) 125 | x = self.bn1(x) 126 | x = self.relu(x) 127 | x = self.maxpool(x) 128 | 129 | fm75 = self.layer1(x) 130 | # 64 x 75 x 75 131 | fm38 = self.layer2(fm75) 132 | # 128 x 38 x 38 133 | fm19 = self.layer3(fm38) 134 | # 256 x 19 x 19 135 | fm10 = self.layer4(fm19) 136 | # 512 x 10 x 10 137 | # x256 = self.avgpool(fm19) 138 | # x = x256.view(x256.size(0), -1) 139 | 140 | # x512 = self.avgpool(fm10) 141 | # x = x512.view(x512.size(0), -1) 142 | # x = self.fc(x) 143 | 144 | return fm75, fm38, fm19, fm10 145 | 146 
| 147 | def XCorr_depthwise(fm1, fm2, padding): 148 | """ 149 | :param fm1: inputs 150 | :param fm2: filters 151 | :param padding: padding in cross correlation layer 152 | :return: similarity map after cross correlation 153 | """ 154 | nchannels = fm1.size()[1] 155 | fm1 = fm1.reshape(-1, fm1.size()[2], fm1.size()[3]) 156 | fm1 = fm1.unsqueeze(0) 157 | fm2 = fm2.reshape(-1, fm2.size()[2], fm2.size()[3]) 158 | fm2 = fm2.unsqueeze(0).permute(1, 0, 2, 3) 159 | out = F.conv2d(fm1, fm2, padding=padding, stride=1, groups=fm2.size()[0]).squeeze() 160 | out = out.reshape(-1, nchannels, out.size()[1], out.size()[2]) 161 | return out 162 | 163 | 164 | class sn_cc_l3(nn.Module): 165 | def __init__(self, hidden=128, backbone='resnet18'): 166 | super(sn_cc_l3, self).__init__() 167 | assert backbone in ['resnet18', 'resnet34', 'resnet50'] 168 | if backbone == 'resnet18': 169 | self.net1 = resnet_modified(BasicBlock, [2, 2, 2, 2]) 170 | elif backbone == 'resnet34': 171 | self.net1 = resnet_modified(BasicBlock, [3, 4, 6, 3]) 172 | elif backbone == 'resnet50': 173 | self.net1 = resnet_modified(Bottleneck, [3, 4, 6, 3]) 174 | else: 175 | raise 176 | self.backbone = backbone 177 | self.nchannels_dict = {'resnet18': {1: 64, 2: 128, 3: 256, 4: 512}, 178 | 'resnet34': {1: 64, 2: 128, 3: 256, 4: 512}, 179 | 'resnet50': {1: 256, 2: 512, 3: 1024, 4: 2048}} 180 | 181 | self.fc1 = nn.Linear(361, hidden) 182 | self.fc2 = nn.Linear(hidden, 2) 183 | 184 | def forward(self, img1, img2): 185 | 186 | _, _, fm1, _ = self.net1(img1) 187 | _, _, fm2, _ = self.net1(img2) 188 | 189 | fm1 = fm1.reshape(-1, fm1.size()[2], fm1.size()[3]) 190 | fm1 = fm1.unsqueeze(0) 191 | new_vec = F.conv2d(fm1, fm2, padding=9, stride=1, groups=fm2.size()[0]).permute(1, 0, 2, 3) 192 | 193 | new_vec = new_vec.view(new_vec.size(0), -1) 194 | out = F.relu(new_vec) 195 | out = self.fc1(out) 196 | out = F.relu(out) 197 | out = self.fc2(out) 198 | return out 199 | 200 | 201 | class psn_cc_l3(nn.Module): 202 | def __init__(self, hidden=128, backbone='resnet18'): 203 | super(psn_cc_l3, self).__init__() 204 | assert backbone in ['resnet18', 'resnet34', 'resnet50'] 205 | if backbone == 'resnet18': 206 | self.net1 = resnet_modified(BasicBlock, [2, 2, 2, 2]) 207 | self.net2 = resnet_modified(BasicBlock, [2, 2, 2, 2]) 208 | elif backbone == 'resnet34': 209 | self.net1 = resnet_modified(BasicBlock, [3, 4, 6, 3]) 210 | self.net2 = resnet_modified(BasicBlock, [3, 4, 6, 3]) 211 | elif backbone == 'resnet50': 212 | self.net1 = resnet_modified(Bottleneck, [3, 4, 6, 3]) 213 | self.net2 = resnet_modified(Bottleneck, [3, 4, 6, 3]) 214 | else: 215 | raise 216 | self.backbone = backbone 217 | self.nchannels_dict = {'resnet18': {1: 64, 2: 128, 3: 256, 4: 512}, 218 | 'resnet34': {1: 64, 2: 128, 3: 256, 4: 512}, 219 | 'resnet50': {1: 256, 2: 512, 3: 1024, 4: 2048}} 220 | 221 | self.fc1 = nn.Linear(361, hidden) 222 | self.fc2 = nn.Linear(hidden, 2) 223 | 224 | def forward(self, img1, img2): 225 | 226 | _, _, fm1, _ = self.net1(img1) 227 | _, _, fm2, _ = self.net2(img2) 228 | 229 | fm1 = fm1.reshape(-1, fm1.size()[2], fm1.size()[3]) 230 | fm1 = fm1.unsqueeze(0) 231 | new_vec = F.conv2d(fm1, fm2, padding=9, stride=1, groups=fm2.size()[0]).permute(1, 0, 2, 3) 232 | 233 | new_vec = new_vec.view(new_vec.size(0), -1) 234 | out = F.relu(new_vec) 235 | out = self.fc1(out) 236 | out = F.relu(out) 237 | out = self.fc2(out) 238 | return out 239 | 240 | 241 | class psn_depthwise_cc_l3(nn.Module): 242 | def __init__(self, backbone='resnet18', nconvs=2, nfilters=256, kernel_size=3): 243 
| super(psn_depthwise_cc_l3, self).__init__() 244 | assert backbone in ['resnet18', 'resnet34', 'resnet50'] 245 | if backbone == 'resnet18': 246 | self.net1 = resnet_modified(BasicBlock, [2, 2, 2, 2]) 247 | self.net2 = resnet_modified(BasicBlock, [2, 2, 2, 2]) 248 | elif backbone == 'resnet34': 249 | self.net1 = resnet_modified(BasicBlock, [3, 4, 6, 3]) 250 | self.net2 = resnet_modified(BasicBlock, [3, 4, 6, 3]) 251 | elif backbone == 'resnet50': 252 | self.net1 = resnet_modified(Bottleneck, [3, 4, 6, 3]) 253 | self.net2 = resnet_modified(Bottleneck, [3, 4, 6, 3]) 254 | else: 255 | raise 256 | self.backbone = backbone 257 | self.nchannels_dict = {'resnet18': {1: 64, 2: 128, 3: 256, 4: 512}, 258 | 'resnet34': {1: 64, 2: 128, 3: 256, 4: 512}, 259 | 'resnet50': {1: 256, 2: 512, 3: 1024, 4: 2048}} 260 | 261 | self.nconvs = nconvs 262 | assert self.nconvs in [0, 1, 2, 3] 263 | 264 | self.nchannels = self.nchannels_dict[self.backbone][3] 265 | 266 | if kernel_size == 3: 267 | padding = 1 268 | elif kernel_size == 1: 269 | padding = 0 270 | else: 271 | raise 272 | 273 | if self.nconvs in [1, 2, 3]: 274 | self.conv_combo_1 = nn.Sequential(nn.Conv2d( 275 | self.nchannels, nfilters, kernel_size=kernel_size, padding=padding, stride=1), 276 | nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 277 | 278 | if self.nconvs in [2, 3]: 279 | self.conv_combo_2 = nn.Sequential(nn.Conv2d( 280 | nfilters, nfilters, kernel_size=kernel_size, padding=padding, stride=1), nn.BatchNorm2d(nfilters), 281 | nn.ReLU(inplace=True)) 282 | 283 | if self.nconvs == 3: 284 | self.conv_combo_3 = nn.Sequential(nn.Conv2d( 285 | nfilters, nfilters, kernel_size=kernel_size, padding=padding, stride=1), nn.BatchNorm2d(nfilters), 286 | nn.ReLU(inplace=True)) 287 | 288 | if self.nconvs == 0: 289 | self.fc = nn.Linear(self.nchannels, 2) 290 | else: 291 | self.fc = nn.Linear(nfilters, 2) 292 | 293 | def forward(self, img1, img2): 294 | 295 | _, _, fm1, _ = self.net1(img1) 296 | _, _, fm2, _ = self.net2(img2) 297 | 298 | # depth-wise cross correlation 299 | nchannels = fm1.size()[1] 300 | fm1 = fm1.reshape(-1, fm1.size()[2], fm1.size()[3]) 301 | fm1 = fm1.unsqueeze(0) 302 | fm2 = fm2.reshape(-1, fm2.size()[2], fm2.size()[3]) 303 | fm2 = fm2.unsqueeze(0).permute(1, 0, 2, 3) 304 | new_vec = F.conv2d(fm1, fm2, padding=9, stride=1, groups=fm2.size()[0]).squeeze() 305 | new_vec = new_vec.reshape(-1, nchannels, new_vec.size()[1], new_vec.size()[2]) 306 | out = F.relu(new_vec) # batch_size * nchannels * H * W 307 | 308 | # convolution layers 309 | if self.nconvs in [1, 2, 3]: 310 | out = self.conv_combo_1(out) 311 | if self.nconvs in [2, 3]: 312 | out = self.conv_combo_2(out) 313 | if self.nconvs == 3: 314 | out = self.conv_combo_3(out) 315 | 316 | # global average pooling 317 | out = F.adaptive_avg_pool2d(out, (1, 1)) 318 | out = out.view(out.size(0), -1) 319 | out = self.fc(out) 320 | return out 321 | 322 | 323 | class psn_depthwise_cc_layerwise_l23(nn.Module): 324 | def __init__(self, backbone='resnet18', nconvs=2, depth=128, nfilters=256, kernel_size=3): 325 | super(psn_depthwise_cc_layerwise_l23, self).__init__() 326 | assert backbone in ['resnet18', 'resnet34', 'resnet50'] 327 | if backbone == 'resnet18': 328 | self.net1 = resnet_modified(BasicBlock, [2, 2, 2, 2]) 329 | self.net2 = resnet_modified(BasicBlock, [2, 2, 2, 2]) 330 | elif backbone == 'resnet34': 331 | self.net1 = resnet_modified(BasicBlock, [3, 4, 6, 3]) 332 | self.net2 = resnet_modified(BasicBlock, [3, 4, 6, 3]) 333 | elif backbone == 'resnet50': 334 | self.net1 = 
resnet_modified(Bottleneck, [3, 4, 6, 3]) 335 | self.net2 = resnet_modified(Bottleneck, [3, 4, 6, 3]) 336 | else: 337 | raise 338 | self.backbone = backbone 339 | self.nchannels_dict = {'resnet18': {1: 64, 2: 128, 3: 256, 4: 512}, 340 | 'resnet34': {1: 64, 2: 128, 3: 256, 4: 512}, 341 | 'resnet50': {1: 256, 2: 512, 3: 1024, 4: 2048}} 342 | 343 | self.extra_1x1_conv_1 = nn.Conv2d(self.nchannels_dict[self.backbone][2], depth, kernel_size=1, stride=1, 344 | padding=0, bias=False) 345 | self.extra_1x1_conv_2 = nn.Conv2d(self.nchannels_dict[self.backbone][3], depth, kernel_size=1, stride=1, 346 | padding=0, bias=False) 347 | 348 | self.depth = depth 349 | self.nconvs = nconvs 350 | assert self.nconvs in [0, 1, 2, 3] 351 | 352 | if kernel_size == 3: 353 | padding = 1 354 | elif kernel_size == 1: 355 | padding = 0 356 | else: 357 | raise 358 | 359 | if self.nconvs in [1, 2, 3]: 360 | self.conv_combo_1 = nn.Sequential(nn.Conv2d( 361 | depth*2, nfilters, kernel_size=kernel_size, padding=padding, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 362 | 363 | if self.nconvs in [2, 3]: 364 | self.conv_combo_2 = nn.Sequential(nn.Conv2d( 365 | nfilters, nfilters, kernel_size=kernel_size, padding=padding, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 366 | 367 | if self.nconvs == 3: 368 | self.conv_combo_3 = nn.Sequential(nn.Conv2d( 369 | nfilters, nfilters, kernel_size=kernel_size, padding=padding, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 370 | 371 | if self.nconvs == 0: 372 | self.fc = nn.Linear(depth*2, 2) 373 | else: 374 | self.fc = nn.Linear(nfilters, 2) 375 | 376 | def forward(self, img1, img2): 377 | 378 | _, fm38_1, fm19_1, _ = self.net1(img1) 379 | _, fm38_2, fm19_2, _ = self.net2(img2) 380 | 381 | # change the number of channels 382 | fm38_1 = self.extra_1x1_conv_1(fm38_1) # batch_size * depth * 38 * 38 383 | fm38_2 = self.extra_1x1_conv_1(fm38_2) # batch_size * depth * 38 * 38 384 | 385 | fm19_1 = self.extra_1x1_conv_2(fm19_1) # batch_size * depth * 19 * 19 386 | fm19_2 = self.extra_1x1_conv_2(fm19_2) # batch_size * depth * 19 * 19 387 | 388 | # depth-wise cross correlation 389 | sm38 = XCorr_depthwise(fm38_1, fm38_2, padding=19) # batch_size * depth * 38 * 38 390 | sm19 = XCorr_depthwise(fm19_1, fm19_2, padding=9) # batch_size * depth * 19 * 19 391 | 392 | # aggregate along the channel 393 | sm38_downsampled = F.adaptive_avg_pool2d(sm38, (sm19.size()[2], sm19.size()[3])) # batch_size * depth * 19 * 19 394 | out = torch.cat([sm38_downsampled, sm19], 1) # batch_size * (2*depth) * 19 * 19 395 | out = F.relu(out) 396 | 397 | # convolution layers 398 | if self.nconvs in [1, 2, 3]: 399 | out = self.conv_combo_1(out) 400 | if self.nconvs in [2, 3]: 401 | out = self.conv_combo_2(out) 402 | if self.nconvs == 3: 403 | out = self.conv_combo_3(out) 404 | 405 | # global average pooling 406 | out = F.adaptive_avg_pool2d(out, (1, 1)) 407 | out = out.view(out.size(0), -1) 408 | out = self.fc(out) 409 | return out 410 | 411 | 412 | class psn_depthwise_cc_layerwise_l34(nn.Module): 413 | def __init__(self, backbone='resnet18', nconvs=2, depth=128, nfilters=256, kernel_size=3): 414 | super(psn_depthwise_cc_layerwise_l34, self).__init__() 415 | assert backbone in ['resnet18', 'resnet34', 'resnet50'] 416 | if backbone == 'resnet18': 417 | self.net1 = resnet_modified(BasicBlock, [2, 2, 2, 2]) 418 | self.net2 = resnet_modified(BasicBlock, [2, 2, 2, 2]) 419 | elif backbone == 'resnet34': 420 | self.net1 = resnet_modified(BasicBlock, [3, 4, 6, 3]) 421 | self.net2 = 
resnet_modified(BasicBlock, [3, 4, 6, 3]) 422 | elif backbone == 'resnet50': 423 | self.net1 = resnet_modified(Bottleneck, [3, 4, 6, 3]) 424 | self.net2 = resnet_modified(Bottleneck, [3, 4, 6, 3]) 425 | else: 426 | raise 427 | self.backbone = backbone 428 | self.nchannels_dict = {'resnet18': {1: 64, 2: 128, 3: 256, 4: 512}, 429 | 'resnet34': {1: 64, 2: 128, 3: 256, 4: 512}, 430 | 'resnet50': {1: 256, 2: 512, 3: 1024, 4: 2048}} 431 | 432 | self.extra_1x1_conv_1 = nn.Conv2d(self.nchannels_dict[self.backbone][3], depth, kernel_size=1, stride=1, 433 | padding=0, bias=False) 434 | self.extra_1x1_conv_2 = nn.Conv2d(self.nchannels_dict[self.backbone][4], depth, kernel_size=1, stride=1, 435 | padding=0, bias=False) 436 | 437 | self.depth = depth 438 | self.nconvs = nconvs 439 | assert self.nconvs in [0, 1, 2, 3] 440 | 441 | if kernel_size == 3: 442 | padding = 1 443 | elif kernel_size == 1: 444 | padding = 0 445 | else: 446 | raise 447 | 448 | if self.nconvs in [1, 2, 3]: 449 | self.conv_combo_1 = nn.Sequential(nn.Conv2d( 450 | depth * 2, nfilters, kernel_size=kernel_size, padding=padding, stride=1), nn.BatchNorm2d(nfilters), 451 | nn.ReLU(inplace=True)) 452 | 453 | if self.nconvs in [2, 3]: 454 | self.conv_combo_2 = nn.Sequential(nn.Conv2d( 455 | nfilters, nfilters, kernel_size=kernel_size, padding=padding, stride=1), nn.BatchNorm2d(nfilters), 456 | nn.ReLU(inplace=True)) 457 | 458 | if self.nconvs == 3: 459 | self.conv_combo_3 = nn.Sequential(nn.Conv2d( 460 | nfilters, nfilters, kernel_size=kernel_size, padding=padding, stride=1), nn.BatchNorm2d(nfilters), 461 | nn.ReLU(inplace=True)) 462 | 463 | if self.nconvs == 0: 464 | self.fc = nn.Linear(depth * 2, 2) 465 | else: 466 | self.fc = nn.Linear(nfilters, 2) 467 | 468 | def forward(self, img1, img2): 469 | 470 | _, _, fm19_1, fm10_1 = self.net1(img1) 471 | _, _, fm19_2, fm10_2 = self.net2(img2) 472 | 473 | # change the number of channels 474 | fm19_1 = self.extra_1x1_conv_1(fm19_1) # batch_size * depth * 19 * 19 475 | fm19_2 = self.extra_1x1_conv_1(fm19_2) # batch_size * depth * 19 * 19 476 | 477 | fm10_1 = self.extra_1x1_conv_2(fm10_1) # batch_size * depth * 10 * 10 478 | fm10_2 = self.extra_1x1_conv_2(fm10_2) # batch_size * depth * 10 * 10 479 | 480 | # depth-wise cross correlation 481 | sm19 = XCorr_depthwise(fm19_1, fm19_2, padding=9) # batch_size * depth * 19 * 19 482 | sm10 = XCorr_depthwise(fm10_1, fm10_2, padding=5) # batch_size * depth * 11 * 11 483 | 484 | # aggregate along the channel 485 | sm19_downsampled = F.adaptive_avg_pool2d(sm19, (sm10.size()[2], sm10.size()[3])) # batch_size * depth * 11 * 11 486 | out = torch.cat([sm19_downsampled, sm10], 1) # batch_size * (2*depth) * 11 * 11 487 | out = F.relu(out) 488 | 489 | # convolution layers 490 | if self.nconvs in [1, 2, 3]: 491 | out = self.conv_combo_1(out) 492 | if self.nconvs in [2, 3]: 493 | out = self.conv_combo_2(out) 494 | if self.nconvs == 3: 495 | out = self.conv_combo_3(out) 496 | 497 | # global average pooling 498 | out = F.adaptive_avg_pool2d(out, (1, 1)) 499 | out = out.view(out.size(0), -1) 500 | out = self.fc(out) 501 | return out 502 | 503 | 504 | class psn_depthwise_cc_layerwise_3layers_l234(nn.Module): 505 | """ 506 | Depthwise cross-correlation, layerwise aggregation (aggregate similarity maps from 3 layers). 
507 | """ 508 | def __init__(self, backbone='resnet18', nconvs=2, depth=128, nfilters=256, kernel_size=3): 509 | super(psn_depthwise_cc_layerwise_3layers_l234, self).__init__() 510 | assert backbone in ['resnet18', 'resnet34', 'resnet50'] 511 | if backbone == 'resnet18': 512 | self.net1 = resnet_modified(BasicBlock, [2, 2, 2, 2]) 513 | self.net2 = resnet_modified(BasicBlock, [2, 2, 2, 2]) 514 | elif backbone == 'resnet34': 515 | self.net1 = resnet_modified(BasicBlock, [3, 4, 6, 3]) 516 | self.net2 = resnet_modified(BasicBlock, [3, 4, 6, 3]) 517 | elif backbone == 'resnet50': 518 | self.net1 = resnet_modified(Bottleneck, [3, 4, 6, 3]) 519 | self.net2 = resnet_modified(Bottleneck, [3, 4, 6, 3]) 520 | else: 521 | raise 522 | self.backbone = backbone 523 | self.nchannels_dict = {'resnet18': {1: 64, 2: 128, 3: 256, 4: 512}, 524 | 'resnet34': {1: 64, 2: 128, 3: 256, 4: 512}, 525 | 'resnet50': {1: 256, 2: 512, 3: 1024, 4: 2048}} 526 | 527 | self.extra_1x1_conv_1 = nn.Conv2d(self.nchannels_dict[self.backbone][2], depth, kernel_size=1, stride=1, padding=0, bias=False) 528 | self.extra_1x1_conv_2 = nn.Conv2d(self.nchannels_dict[self.backbone][3], depth, kernel_size=1, stride=1, padding=0, bias=False) 529 | self.extra_1x1_conv_3 = nn.Conv2d(self.nchannels_dict[self.backbone][4], depth, kernel_size=1, stride=1, padding=0, bias=False) 530 | 531 | self.depth = depth 532 | self.nconvs = nconvs 533 | assert self.nconvs in [0, 1, 2, 3] 534 | 535 | if kernel_size == 3: 536 | padding = 1 537 | elif kernel_size == 1: 538 | padding = 0 539 | else: 540 | raise 541 | 542 | if self.nconvs in [1, 2, 3]: 543 | self.conv_combo_1 = nn.Sequential(nn.Conv2d( 544 | depth*3, nfilters, kernel_size=kernel_size, padding=padding, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 545 | 546 | if self.nconvs in [2, 3]: 547 | self.conv_combo_2 = nn.Sequential(nn.Conv2d( 548 | nfilters, nfilters, kernel_size=kernel_size, padding=padding, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 549 | 550 | if self.nconvs == 3: 551 | self.conv_combo_3 = nn.Sequential(nn.Conv2d( 552 | nfilters, nfilters, kernel_size=kernel_size, padding=padding, stride=1), nn.BatchNorm2d(nfilters), nn.ReLU(inplace=True)) 553 | 554 | if self.nconvs == 0: 555 | self.fc = nn.Linear(depth*3, 2) 556 | else: 557 | self.fc = nn.Linear(nfilters, 2) 558 | 559 | def forward(self, img1, img2): 560 | 561 | _, fm38_1, fm19_1, fm10_1 = self.net1(img1) 562 | _, fm38_2, fm19_2, fm10_2 = self.net2(img2) 563 | 564 | # change the number of channels 565 | fm38_1 = self.extra_1x1_conv_1(fm38_1) # batch_size * depth * 39 * 39 566 | fm38_2 = self.extra_1x1_conv_1(fm38_2) # batch_size * depth * 39 * 39 567 | 568 | fm19_1 = self.extra_1x1_conv_2(fm19_1) # batch_size * depth * 19 * 19 569 | fm19_2 = self.extra_1x1_conv_2(fm19_2) # batch_size * depth * 19 * 19 570 | 571 | fm10_1 = self.extra_1x1_conv_3(fm10_1) # batch_size * depth * 11 * 11 572 | fm10_2 = self.extra_1x1_conv_3(fm10_2) # batch_size * depth * 11 * 11 573 | 574 | # depth-wise cross correlation 575 | sm38 = XCorr_depthwise(fm38_1, fm38_2, padding=19) # batch_size * depth * 39 * 39 576 | sm19 = XCorr_depthwise(fm19_1, fm19_2, padding=9) # batch_size * depth * 19 * 19 577 | sm10 = XCorr_depthwise(fm10_1, fm10_2, padding=5) # batch_size * depth * 11 * 11 578 | 579 | # aggregate along the channel 580 | sm38_downsampled = F.adaptive_avg_pool2d(sm38, (sm10.size()[2], sm10.size()[3])) # batch_size * depth * 11 * 11 581 | sm19_downsampled = F.adaptive_avg_pool2d(sm19, (sm10.size()[2], sm10.size()[3])) # 
batch_size * depth * 11 * 11 582 | out = torch.cat([sm38_downsampled, sm19_downsampled, sm10], 1) # batch_size * (3*depth) * 11 * 11 583 | out = F.relu(out) 584 | 585 | # convolution layers 586 | if self.nconvs in [1, 2, 3]: 587 | out = self.conv_combo_1(out) 588 | if self.nconvs in [2, 3]: 589 | out = self.conv_combo_2(out) 590 | if self.nconvs == 3: 591 | out = self.conv_combo_3(out) 592 | 593 | # global average pooling 594 | out = F.adaptive_avg_pool2d(out, (1, 1)) 595 | out = out.view(out.size(0), -1) 596 | out = self.fc(out) 597 | return out 598 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepSolar++: Understanding Residential Solar Adoption Trajectories with Computer Vision and Technology Diffusion Model 2 | 3 | A deep learning framework to detect solar PV installations from historical satellite/aerial images and predict their installation years. The model is applied to different places across the U.S. to uncover solar adoption trajectories over time. The heterogeneity in solar adoption trajectories is further analyzed from the perspective of a technology diffusion model. 4 | 5 | To use the code, please cite the [paper](https://www.cell.com/joule/pdf/S2542-4351(22)00477-9.pdf): 6 | 7 | * Wang, Z., Arlt, M. L., Zanocco, C., Majumdar, A., & Rajagopal, R. (2022). DeepSolar++: Understanding residential solar adoption trajectories with computer vision and technology diffusion models. Joule, 6(11), 2611-2625. DOI: https://doi.org/10.1016/j.joule.2022.09.011 8 | 9 | This code repo was developed on Ubuntu 16.04, but it should also run in other environments. The Python version used for development is Python 3.6. 10 | 11 | ## Public data access 12 | 13 | The census-block-group level time-series residential solar installation dataset can be accessed here: 14 | 15 | * 420 counties (which are analyzed in the [DeepSolar++](https://www.cell.com/joule/pdf/S2542-4351(22)00477-9.pdf) paper): [CSV file data downloading link](https://opendatasharing.s3.us-west-2.amazonaws.com/DeepSolar2/data/residential_solar_installations_panel_data_420counties.csv) 16 | 17 | * All counties in the contiguous US: [CSV file data downloading link](https://opendatasharing.s3.us-west-2.amazonaws.com/DeepSolar2/data/residential_solar_installations_panel_data_contiguous_US.csv). 18 | 19 | In the CSV file, each census block group corresponds to a row identified by its FIPS code (column 'blockgroup_FIPS'). For each block group, the table contains its cumulative number of residential PV installations in each year from 2005 to 2017, average GHI, building count, and ACS demographic data. 20 | 21 | For the installation information from 2017 to 2022, we are still working on updating the data using the same methodology and will release it at some point in 2023. 22 | 23 | ## Install required packages 24 | 25 | Run the following command: 26 | 27 | ``` 28 | $ pip install -r requirements.txt 29 | ``` 30 | 31 | **Note**: multivariate OLS and logit regressions are run by the R code blocks inserted in the Python Jupyter Notebook `bass_model_parameter_regression.ipynb`, which relies on the `rpy2` package. If you want to run these regressions, R and the R kernel need to be installed. Moreover, the R libraries `lmtest` and `sandwich` are required (and may also need to be installed).
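As a minimal illustrative sketch (not the exact code in `bass_model_parameter_regression.ipynb`) of how such R blocks are invoked from a Python notebook via `rpy2`, assuming a pandas data frame `df_example` with hypothetical columns `q`, `median_income`, and `incentive`:

```
# Python cell: enable the rpy2 notebook magic
%load_ext rpy2.ipython
```

```
%%R -i df_example
# R cell: OLS with heteroskedasticity-robust standard errors via lmtest + sandwich
library(lmtest)
library(sandwich)
fit <- lm(q ~ median_income + incentive, data = df_example)
coeftest(fit, vcov = vcovHC(fit, type = "HC1"))
```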
32 | For further details about installing R and the R kernel, see this [tutorial](https://linuxize.com/post/how-to-install-r-on-ubuntu-20-04/) and this [tutorial](https://datatofish.com/r-jupyter-notebook/). For further details about using R in a Python notebook, see [this](https://stackoverflow.com/questions/39008069/r-and-python-in-one-jupyter-notebook). 33 | 34 | ## Download data and model checkpoints 35 | 36 | Run the following commands to download the ZIP files right under the code repo directory: 37 | 38 | ``` 39 | $ curl -O https://opendatasharing.s3.us-west-2.amazonaws.com/DeepSolar2/checkpoint.zip 40 | $ curl -O https://opendatasharing.s3.us-west-2.amazonaws.com/DeepSolar2/data.zip 41 | $ curl -O https://opendatasharing.s3.us-west-2.amazonaws.com/DeepSolar2/results.zip 42 | ``` 43 | 44 | Unzip them such that the directory structure looks like: 45 | 46 | ``` 47 | DeepSolar_timelapse/checkpoint/... 48 | DeepSolar_timelapse/data/... 49 | DeepSolar_timelapse/results/... 50 | ``` 51 | 52 | **Note 1**: for the satellite/aerial imagery datasets under the `data` directory (subdirectory `HR_images` for high-resolution (HR) images, `LR_images` for low-resolution (LR) images, `blur_detection_images` for blur detection images, and `sequences` for image sequences), we are not able to publicly share the full data due to restrictions of the imagery data sources. Instead, for each subset (training/validation/test) and each class (e.g., positive/negative), we share two example images as a demo. For the image sequence dataset (`sequences`), we share one demo sequence (`sequences/demo_sequences/1`). Each image sequence contains satellite/aerial images captured in different years at the same location as a solar installation (image file name examples: `2006_0.png`, `2007_0.png`, `2007_1.png`, `2008_0.png`, etc.). Users can put their own data under these directories. 53 | 54 | **Note 2**: to run a Jupyter Notebook, the default kernel/environment is "conda_tensorflow_p36", which does not necessarily exist on your computer. Please change the kernel to one where all required packages are installed. 55 | 56 | ## Functionality of each script/notebook 57 | 58 | ### Part 1: model training with hyperparameter search 59 | 60 | An image is first classified by the blur detection model into one of three classes according to its resolution: high resolution (HR), low resolution (LR), and extremely blurred/out of distribution (OOD). An OOD image is not used for determining the solar installation year; an HR image is classified by a single-branch CNN into two classes: positive (containing solar PV) and negative (otherwise); an LR image is classified by a two-branch Siamese CNN into two classes: positive (containing solar PV) and negative (otherwise). 61 | 62 | For training the blur detection model with hyperparameter search: 63 | ``` 64 | $ python hp_search_ood_multilabels.py 65 | ``` 66 | For training the HR model with hyperparameter search: 67 | ``` 68 | $ python hp_search_HR.py 69 | ``` 70 | For training the LR model with hyperparameter search: 71 | ``` 72 | $ python hp_search_LR_rgb.py 73 | ``` 74 | 75 | By default, all three scripts above are run on a machine with a GPU. 76 | 77 | ### Part 2: deploying models to predict installation year 78 | 79 | For each solar installation, we can retrieve a sequence of images captured in different years at its location and put them in the same folder. The images are named as `{image_capture_year}_{auxiliary_index}.png`. 
For example, if there are three images captured in 2012, they are named `2012_0.png`, `2012_1.png`, and `2012_2.png`, respectively. 80 | 81 | For each image sequence, we deploy the blur detection model, the HR model, and the LR model. Their outputs are combined to predict the installation year of the solar PV system. 82 | 83 | First, we deploy the blur detection model and HR model to image sequences: 84 | ``` 85 | $ python predict_HR.py 86 | ``` 87 | ``` 88 | $ python predict_ood_multilabels.py 89 | ``` 90 | 91 | Combining the prediction results of the above two models, we can generate the "anchor_images_dict" that maps a target image in a sequence to all its reference images in this sequence. This needs to be run before deploying the LR model, as the LR model takes a target image and one of its corresponding reference images as inputs. To do this, run the code blocks in this Jupyter Notebook: 92 | 93 | ``` 94 | generate_anchor_image_dict.ipynb 95 | ``` 96 | 97 | Then, deploy the LR model to image sequences: 98 | ``` 99 | $ python predict_LR_rgb.py 100 | ``` 101 | 102 | By default, all three `.py` scripts above are run on a machine with a GPU. 103 | 104 | Finally, run the code blocks in this Jupyter Notebook, which combines all model prediction outputs to predict the installation year for each solar PV system: 105 | 106 | ``` 107 | predict_installation_year_from_image_sequences.ipynb 108 | ``` 109 | 110 | ### Part 3: analyzing solar adoption trajectories across time using the Bass model 111 | 112 | By predicting the installation year of each system, we are able to obtain the number of installations in each year in each place. We provide such a solar installation time-series dataframe at the census block group level, covering the randomly sampled 420 counties, in ``results``. This dataframe is used as input for the following analysis. 113 | 114 | For the solar adoption trajectory in each block group, we use a classical technology adoption model, the [Bass model](https://pubsonline.informs.org/doi/abs/10.1287/mnsc.15.5.215?casa_token=PXhDNyJRVhgAAAAA:ZbFnu9tpKcAJoUDE6JlpMyWvaaa0hyXeuFA2Edbg8EORBlPTUVHBWShq6c1yuA5SBaPBRyLCW1Q), to parameterize its shape. First, fit the Bass model curves using non-linear least squares (NLS): 115 | 116 | ``` 117 | $ python bass_model_fit_adoption_curves.py 118 | ``` 119 | 120 | Then we can use two Jupyter Notebooks to analyze the fitted Bass model parameters. They can be run without running the Bass model fitting code, as the intermediate results of model fitting have already been provided in `results`: 121 | 122 | `bass_model_parameter_phase_analysis.ipynb`: Based on the fitted Bass model parameters, segment each solar adoption trajectory into four phases: pre-diffusion, ramp-up, ramp-down, and saturation, and analyze the fractions of block groups in each phase. 123 | 124 | `bass_model_parameter_regression.ipynb`: Run multivariate regressions with various Bass model parameters as dependent variables and socioeconomic characteristics (demographics, PV benefit, the presence of incentives, etc.) as independent variables. R code blocks are inserted for running these regressions. 
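For reference, the cumulative adoption curve fitted in `bass_model_fit_adoption_curves.py` has the standard Bass model form, where `p` is conventionally interpreted as the innovation coefficient, `q` as the imitation coefficient, `m` as the market potential, and `d` as the onset year of adoption. A minimal sketch of that curve (it mirrors the `get_cum_installations` function in the fitting script):

```
import numpy as np

def bass_cumulative(t, p, q, m, d):
    # Cumulative number of adopters at time t; zero before the adoption onset d
    y = (1 - np.exp(-(p + q) * (t - d))) / (1 + (q / p) * np.exp(-(p + q) * (t - d)))
    return m * y * (t >= d)
```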
125 | 126 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzhecheng/DeepSolar_timelapse/b9f0e0e81b901426a166f55a3ba6ab53ea260407/__init__.py -------------------------------------------------------------------------------- /bass_model_fit_adoption_curves.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import random 4 | import os 5 | from os.path import join, exists 6 | import pickle 7 | from tqdm import tqdm 8 | import matplotlib.pyplot as plt 9 | from matplotlib import patches 10 | from copy import deepcopy 11 | from utils import * 12 | import statsmodels.api as sm 13 | import statsmodels.formula.api as smf 14 | from sklearn.linear_model import LinearRegression 15 | from sklearn.metrics import r2_score 16 | from scipy import stats 17 | from scipy.optimize import least_squares 18 | 19 | """ 20 | This script fits the Bass model curve to the solar adoption curve of each census block 21 | group. For each block group, there will be a set of estimated Bass model parameters, including d, p, 22 | q, and m. 23 | """ 24 | price_trend = np.array([10.0, 10.1, 9.6, 9.2, 7.9, 7.0, 5.9, 5.1, 4.8, 4.7, 4.4]) 25 | log_price_trend = np.log(price_trend / price_trend[0]) 26 | 27 | def get_cum_installations(coefs, t): 28 | """Given Bass model coefficients and a time array, return the array of cumulative values.""" 29 | p, q, m, d = coefs 30 | y = (1 - np.exp(-(p + q) * (t - d))) / (1 + q / p * np.exp(-(p + q) * (t - d))) 31 | return m * y * (t >= d) + 0 * (t < d) 32 | 33 | def get_residuals_cum(coefs, t, y_true): 34 | """ 35 | Given Bass model coefficients and a time array, return the residuals between true cumulative 36 | values and predicted cumulative values (predicted by the Bass model). 37 | """ 38 | y_pred = get_cum_installations(coefs, t) 39 | return y_pred - y_true 40 | 41 | def get_cum_installations_GBM(coefs, t): 42 | """ 43 | Given Generalized Bass model coefficients and a time array, return the array of 44 | cumulative values. 45 | """ 46 | p, q, m, d, beta = coefs 47 | y = (1 - np.exp(-(p + q) * (t - d + beta * log_price_trend))) / (1 + q / p * np.exp(-(p + q) * (t - d + beta * log_price_trend))) 48 | return m * y * (t >= d) + 0 * (t < d) 49 | 50 | def get_residuals_cum_GBM(coefs, t, y_true): 51 | """ 52 | Given Generalized Bass model coefficients and a time array, return the residuals between true 53 | cumulative values and predicted cumulative values (predicted by the Bass model). 54 | """ 55 | y_pred = get_cum_installations_GBM(coefs, t) 56 | return y_pred - y_true 57 | 58 | def get_best_fit_NLS(y_arr, dy_arr, initial_p=0.01, initial_q=0.5, initial_m=None, upper_m=np.inf): 59 | """ 60 | Given the arrays of cumulative installations in the previous year and new installations in the current year, 61 | fit the Bass model curve using Non-Linear Least Squares (NLS) with initial values and bounds. 62 | Different initial values of d (onset of adoption) will be tried to find the best initial value. 63 | Args: 64 | y_arr: the array of cumulative installations in the previous year. 65 | dy_arr: the array of new installations in the current year. 66 | initial_p: initial value of p. 67 | initial_q: initial value of q. 68 | initial_m: initial value of m. 69 | upper_m: the upper bound of m. 70 | Return: 71 | min_rmse: the minimum RMSE of fitting. 
72 | best_params: the best set of Bass model parameters (p, q, m, d). 73 | best_y_pred: the prediction of cumulative installations under the best parameters. 74 | best_dy_pred: the prediction of yearly new installations under the best parameters. 75 | best_onset_idx: the best initial value of d that will yield the minimum RMSE. 76 | """ 77 | y_true = np.array(list(y_arr[1:]) + [y_arr[-1] + dy_arr[-1]]) 78 | if initial_m is None: 79 | initial_m = y_true[-1] 80 | min_rmse = np.inf 81 | best_params = None 82 | best_y_pred = None 83 | best_onset_idx = None 84 | for onset_idx in range(-5, 11): 85 | ls_model = least_squares( 86 | get_residuals_cum, 87 | x0=np.array([initial_p, initial_q, initial_m, onset_idx]), 88 | jac='cs', 89 | bounds=([0, 0, 0, -8], [1, np.inf, upper_m, np.inf]), 90 | args=(np.arange(0, 11), y_true), 91 | method='trf', 92 | ) 93 | y_pred = get_cum_installations(ls_model.x, np.arange(0, 11)) 94 | rmse = np.sqrt(np.sum((y_pred - y_true) ** 2)) 95 | p, q, m, d = ls_model.x 96 | if rmse < min_rmse and p > 0 and q > 0 and m > 0: 97 | min_rmse = rmse 98 | best_params = ls_model.x 99 | best_y_pred = y_pred 100 | best_onset_idx = onset_idx 101 | best_dy_pred = np.concatenate([[best_y_pred[0]], best_y_pred[1:] - best_y_pred[:-1]]) 102 | return min_rmse, best_params, best_y_pred, best_dy_pred, best_onset_idx 103 | 104 | def get_best_fit_NLS_GBM(y_arr, dy_arr, initial_p=0.01, initial_q=0.5, initial_m=None, initial_beta=-0.4): 105 | """ 106 | Given the arrays of cumulative installations in previous year and new installations in current year, 107 | fit the Generalized Bass model (GBM) curve using Non-Linear Least Square (NLS) with initial values and bounds. 108 | Different initial values of d (onset of adoption) will be tried to get its best initial value. 109 | Args: 110 | y_arr: the array of cumulative installations in previous year. 111 | dy_arr: the array of new installations in present year. 112 | initial_p: initial value of p. 113 | initial_q: initial value of q. 114 | initial_m: initial value of m. 115 | upper_m: the upper bound of m. 116 | unitial_beta: the initial value of beta. 117 | Return: 118 | min_rmse: the minimum RMSE of fitting. 119 | best_params: the best set of Bass model parameters (p, q, m, d, beta). 120 | best_y_pred: the prediction of cumulative installations under the best parameters. 121 | best_dy_pred: the prediction of yearly new installations under the best parameters. 122 | best_onset_idx: the best initial value of d that will yield the minimum RMSE. 
123 | """ 124 | y_true = np.array(list(y_arr[1:]) + [y_arr[-1] + dy_arr[-1]]) 125 | if initial_m is None: 126 | initial_m = y_true[-1] 127 | min_rmse = np.inf 128 | best_params = None 129 | best_y_pred = None 130 | best_onset_idx = None 131 | for onset_idx in range(-5, 11): 132 | ls_model = least_squares( 133 | get_residuals_cum_GBM, 134 | x0=np.array([initial_p, initial_q, initial_m, onset_idx, initial_beta]), 135 | jac='cs', 136 | bounds=([0, 0, 0, -np.inf, -np.inf], [1, np.inf, np.inf, np.inf, 0]), 137 | args=(np.arange(0, 11), y_true), 138 | method='trf', 139 | ) 140 | y_pred = get_cum_installations_GBM(ls_model.x, np.arange(0, 11)) 141 | rmse = np.sqrt(np.sum((y_pred - y_true) ** 2)) 142 | p, q, m, d, beta = ls_model.x 143 | if rmse < min_rmse and p > 0 and q > 0 and m > 0: 144 | min_rmse = rmse 145 | best_params = ls_model.x 146 | best_y_pred = y_pred 147 | best_onset_idx = onset_idx 148 | best_dy_pred = np.concatenate([[best_y_pred[0]], best_y_pred[1:] - best_y_pred[:-1]]) 149 | return min_rmse, best_params, best_y_pred, best_dy_pred, best_onset_idx 150 | 151 | # 1. Load and process data 152 | bg = pd.read_csv('results/merged_bg.csv') 153 | 154 | bg = bg[['blockgroup_FIPS', 'year', 'num_of_installations', 'num_of_buildings_lt600']] 155 | 156 | bg['tract_FIPS'] = bg['blockgroup_FIPS'] // 10 157 | bg['county_FIPS'] = bg['blockgroup_FIPS'] // 10000000 158 | bg['state_FIPS'] = bg['blockgroup_FIPS'] // 10000000000 159 | 160 | df = bg.sort_values(['year', 'blockgroup_FIPS']) 161 | 162 | cumulative_pv_count_dict = {} 163 | cumulative_pv_count = np.array(df[df['year'] == 2005]['num_of_installations']) 164 | cumulative_pv_count_dict[2005] = cumulative_pv_count.copy() 165 | for year in range(2006, 2018): 166 | cumulative_pv_count = cumulative_pv_count + np.array(df[df['year'] == year]['num_of_installations']) 167 | cumulative_pv_count_dict[year] = cumulative_pv_count.copy() 168 | 169 | df['cum_num_of_installations'] = np.concatenate([cumulative_pv_count_dict[x] for x in cumulative_pv_count_dict]) 170 | df.sort_values(['blockgroup_FIPS', 'year'], inplace=True) 171 | 172 | prev_year_cum = df[(df['year'] <= 2015) & (df['year'] >= 2005)][['blockgroup_FIPS', 173 | 'year', 174 | 'cum_num_of_installations']] 175 | # prev_year_cum.index = prev_year_cum['blockgroup_FIPS'] 176 | prev_year_cum.rename(columns={'cum_num_of_installations': 'cum_num_of_installations_prev'}, inplace=True) 177 | prev_year_cum['year'] = prev_year_cum['year'] + 1 178 | 179 | df_sub = df[(df['year'] <= 2016) & (df['year'] >= 2006)] 180 | df_sub = pd.merge(df_sub, prev_year_cum, how='left', on=['blockgroup_FIPS', 'year']) 181 | 182 | adoption_matrix_bg = pd.DataFrame(df_sub['num_of_installations'].to_numpy().reshape([-1, 11])) 183 | adoption_matrix_bg.index = df_sub[df_sub['year'] == 2016]['blockgroup_FIPS'] 184 | adoption_matrix_bg.columns = [str(x) for x in range(2006, 2017)] 185 | 186 | cum_adoption_matrix_prev_bg = pd.DataFrame(df_sub['cum_num_of_installations_prev'].to_numpy().reshape([-1, 11])) 187 | cum_adoption_matrix_prev_bg.index = df_sub[df_sub['year'] == 2016]['blockgroup_FIPS'] 188 | cum_adoption_matrix_prev_bg.columns = [str(x) for x in range(2006, 2017)] 189 | 190 | df_buildings = df[df['year'] == 2016][['blockgroup_FIPS', 'num_of_buildings_lt600']] 191 | df_buildings.index = df_buildings['blockgroup_FIPS'] 192 | 193 | del df 194 | 195 | # Normal Bass Model: Block group level curve fitting 196 | adoption_params_dict = {} 197 | i = 0 198 | for bfips in tqdm(cum_adoption_matrix_prev_bg.index): 199 | i += 1 200 | 
y_arr = cum_adoption_matrix_prev_bg.loc[bfips, :].to_numpy() 201 | dy_arr = adoption_matrix_bg.loc[bfips, :].to_numpy() 202 | if np.any(dy_arr > 0): 203 | base = df_buildings.loc[bfips, 'num_of_buildings_lt600'] 204 | base = max(base, 1.0) 205 | y_true = np.array(list(y_arr[1:]) + [y_arr[-1] + dy_arr[-1]]) 206 | rmse, (p, q, m, d), y_pred, dy_pred, best_onset_idx = get_best_fit_NLS(y_arr, dy_arr, 207 | initial_m=min(base, y_true[-1]), 208 | upper_m=base) # BM 209 | r2 = r2_score(y_true, y_pred) 210 | else: 211 | p, q, m, d, rmse, r2 = 0, 0, y_arr[0], None, 0., 1. 212 | adoption_params_dict[bfips] = [p, q, m, d, rmse, r2] 213 | if i % 5000 == 0: 214 | adoption_params_bg = pd.DataFrame(adoption_params_dict).transpose() 215 | adoption_params_bg.columns = ['p', 'q', 'm', 'd', 'rmse', 'r2'] 216 | adoption_params_bg.to_csv('results/bass_model/adoption_bass_model_params_bg.csv') 217 | 218 | adoption_params_bg = pd.DataFrame(adoption_params_dict).transpose() 219 | adoption_params_bg.columns = ['p', 'q', 'm', 'd', 'rmse', 'r2'] 220 | adoption_params_bg.to_csv('results/bass_model/adoption_bass_model_params_bg.csv') 221 | -------------------------------------------------------------------------------- /generate_anchor_image_dict.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import os.path\n", 11 | "import pickle\n", 12 | "import shutil\n", 13 | "import pandas as pd\n", 14 | "import random\n", 15 | "import math\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "from os.path import join, exists\n", 18 | "import numpy as np\n", 19 | "from tqdm import tqdm\n", 20 | "import random\n", 21 | "%matplotlib inline" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "with open('results/ood_prob_dict.pickle', 'rb') as f:\n", 31 | " ood_prob_dict = pickle.load(f)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "with open('results/HR_prob_dict.pickle', 'rb') as f:\n", 41 | " HR_prob_dict = pickle.load(f)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "# Generate anchor image dict for deriving image pair (ref + target) dataset\n", 49 | "\n", 50 | "A \"anchor_images_dict\" maps a target image in a sequence to all its reference images in this sequence. This is needed to be run before deploying the LR model." 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 7, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "dir_list = ['demo_sequences']" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 8, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "root_data_dir = 'data/sequences'" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 9, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "HR_threshold = 0.85" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 10, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "def anchor_model_1_get_all_images(prob_dict, fname_list, threshold=0.5):\n", 87 | " \"\"\"\n", 88 | " prob_dict: {fname: probability}, e.g. 
{'2006_0.png': 0.001, '2007_1.png': 0.132, ...}\n", 89 | " fname_list: a candidate list of filenames\n", 90 | " \"\"\"\n", 91 | " anchor_images = []\n", 92 | " for fname in fname_list:\n", 93 | " prob = prob_dict[fname]\n", 94 | " if prob >= threshold:\n", 95 | " anchor_images.append(fname)\n", 96 | " if len(anchor_images) == 0:\n", 97 | " anchor_images.append(fname_list[-1])\n", 98 | " return anchor_images" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 11, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "anchor_images_dict = dict()\n", 108 | "for subdir in dir_list:\n", 109 | " data_dir = join(root_data_dir, subdir)\n", 110 | " for folder in os.listdir(data_dir):\n", 111 | " idx = folder.split('_')[0]\n", 112 | " folder_dir = join(data_dir, folder)\n", 113 | " fname_list = []\n", 114 | " for f in os.listdir(folder_dir):\n", 115 | " if f[-4:] == '.png':\n", 116 | " fname_list.append(f)\n", 117 | " HR_prob_dict_sub = HR_prob_dict[idx]\n", 118 | " anchor_images = anchor_model_1_get_all_images(HR_prob_dict_sub, fname_list, threshold=HR_threshold)\n", 119 | " anchor_images_dict[idx] = anchor_images" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 12, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "with open('results/anchor_images_dict.pickle', 'wb') as f:\n", 129 | " pickle.dump(anchor_images_dict, f)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [] 138 | } 139 | ], 140 | "metadata": { 141 | "kernelspec": { 142 | "display_name": "Environment (conda_tensorflow_p36)", 143 | "language": "python", 144 | "name": "conda_tensorflow_p36" 145 | }, 146 | "language_info": { 147 | "codemirror_mode": { 148 | "name": "ipython", 149 | "version": 3 150 | }, 151 | "file_extension": ".py", 152 | "mimetype": "text/x-python", 153 | "name": "python", 154 | "nbconvert_exporter": "python", 155 | "pygments_lexer": "ipython3", 156 | "version": "3.6.5" 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 4 161 | } 162 | -------------------------------------------------------------------------------- /hp_search_HR.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils.data import Dataset, DataLoader 7 | import torchvision 8 | from torchvision import datasets, models, transforms, utils 9 | import torchvision.transforms.functional as TF 10 | 11 | from tqdm import tqdm 12 | import numpy as np 13 | import pandas as pd 14 | import pickle 15 | import matplotlib.pyplot as plt 16 | # import skimage 17 | # import skimage.io 18 | # import skimage.transform 19 | from PIL import Image 20 | import time 21 | import os 22 | from os.path import join, exists 23 | import copy 24 | import random 25 | from collections import OrderedDict 26 | from sklearn.metrics import r2_score 27 | 28 | from torch.nn import functional as F 29 | from torchvision.models import Inception3 30 | 31 | from utils.image_dataset import BinaryImageFolder 32 | 33 | """ 34 | This script is for training solar panel classification model with high-resolution (HR) 35 | image input. It is a single-branch CNN based on Inception v3 model. The hyperparameters 36 | to search include learning rate and learning rate decay epochs. 
37 | """ 38 | 39 | # Configuration 40 | # directory for loading training/validation/test data 41 | data_dirs_dict = { 42 | 'train': ['data/HR_images/train'], 43 | 'val': ['data/HR_images/val'] 44 | } 45 | 46 | 47 | # path to load old model/checkpoint, "None" if not loading. 48 | old_ckpt_path = None 49 | # directory for saving model/checkpoint 50 | ckpt_save_dir = 'checkpoint/HR_new_model' 51 | 52 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 53 | trainable_params = None # layers or modules set to be trainable. "None" if training all layers 54 | model_name = 'HR' # the prefix of the filename for saving model/checkpoint 55 | return_best = True # whether to return the best model according to the validation metrics 56 | if_early_stop = True # whether to stop early after validation metrics doesn't improve for definite number of epochs 57 | input_size = 299 # image size fed into the model 58 | imbalance_rate = 1.0 # weight given to the positive (rarer) samples in loss function 59 | # learning_rate = 0.01 # learning rate 60 | weight_decay = 0.00 # l2 regularization coefficient 61 | batch_size = 64 62 | num_epochs = 100 # number of epochs to train 63 | lr_decay_rate = 0.5 # learning rate decay rate for each decay step 64 | # lr_decay_epochs = 10 # number of epochs for one learning rate decay 65 | early_stop_epochs = 10 # after validation metrics doesn't improve for "early_stop_epochs" epochs, stop the training. 66 | save_epochs = 10 # save the model/checkpoint every "save_epochs" epochs 67 | # threshold = 0.2 # threshold probability to identify am image as positive 68 | threshold_list = np.linspace(0.0, 1.0, 101) 69 | 70 | # hyperparamters to tune 71 | lr_list = [0.0001, 0.001, 0.00001] # learning rate 72 | lr_decay_epochs_list = [10] # learning rate decay epochs 73 | 74 | 75 | def RandomRotationNew(image): 76 | angle = random.choice([0, 90, 180, 270]) 77 | image = TF.rotate(image, angle) 78 | return image 79 | 80 | def only_train(model, trainable_params): 81 | """trainable_params: The list of parameters and modules that are set to be trainable. 
82 | Set require_grad = False for all those parameters not in the trainable_params""" 83 | print('Only the following layers:') 84 | for name, p in model.named_parameters(): 85 | p.requires_grad = False 86 | for target in trainable_params: 87 | if target == name or target in name: 88 | p.requires_grad = True 89 | print(' ' + name) 90 | break 91 | 92 | def metrics(stats): 93 | """stats: {'TP': TP, 'FP': FP, 'TN': TN, 'FN': FN} 94 | return: must be a single number """ 95 | # precision = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FP'] + 0.00001) 96 | # recall = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FN'] + 0.00001) 97 | # F1 = 2.0 * stats['TP'] / (2 * stats['TP'] + stats['FP'] + stats['FN']) 98 | spec = (stats['TN'] + 0.00001) * 1.0 / (stats['TN'] + stats['FP'] + 0.00001) 99 | sens = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FN'] + 0.00001) 100 | return 2.0 * spec * sens / (spec + sens + 1e-7) 101 | 102 | 103 | def train_model(model, model_name, dataloaders, criterion, optimizer, metrics, num_epochs, training_log=None, 104 | verbose=True, return_best=True, if_early_stop=True, early_stop_epochs=10, scheduler=None, 105 | save_dir=None, save_epochs=5): 106 | since = time.time() 107 | if not training_log: 108 | training_log = dict() 109 | training_log['train_loss_history'] = [] 110 | training_log['val_loss_history'] = [] 111 | training_log['val_metric_value_history'] = [] 112 | training_log['epoch_best_threshold_history'] = [] 113 | training_log['current_epoch'] = -1 114 | current_epoch = training_log['current_epoch'] + 1 115 | 116 | best_model_wts = copy.deepcopy(model.state_dict()) 117 | best_optimizer_wts = copy.deepcopy(optimizer.state_dict()) 118 | best_log = copy.deepcopy(training_log) 119 | 120 | best_metric_value = -np.inf 121 | best_threshold = 0 122 | nodecrease = 0 # to count the epochs that val loss doesn't decrease 123 | early_stop = False 124 | 125 | for epoch in range(current_epoch, current_epoch + num_epochs): 126 | if verbose: 127 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 128 | print('-' * 10) 129 | 130 | # Each epoch has a training and validation phase 131 | for phase in ['train', 'val']: 132 | if phase == 'train': 133 | model.train() # Set model to training mode 134 | else: 135 | model.eval() # Set model to evaluate mode 136 | 137 | running_loss = 0.0 138 | stats = {x: {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0} for x in threshold_list} 139 | 140 | # Iterate over data. 
141 | for inputs, labels, _ in tqdm(dataloaders[phase]): 142 | inputs = inputs.to(device) 143 | labels = labels.to(device) 144 | 145 | # zero the parameter gradients 146 | optimizer.zero_grad() 147 | 148 | # forward 149 | # track history if only in train 150 | with torch.set_grad_enabled(phase == 'train'): 151 | # Get model outputs and calculate loss 152 | if phase == 'train': 153 | outputs, aux_outputs = model(inputs) 154 | loss1 = criterion(outputs, labels) 155 | loss2 = criterion(aux_outputs, labels) 156 | loss = loss1 + 0.4*loss2 157 | # backward + optimize only if in training phase 158 | loss.backward() 159 | optimizer.step() 160 | 161 | else: 162 | outputs = model(inputs) 163 | loss = criterion(outputs, labels) 164 | # val phase: calculate metrics under different threshold 165 | prob = F.softmax(outputs, dim=1) 166 | for threshold in threshold_list: 167 | preds = prob[:, 1] >= threshold 168 | stats[threshold]['TP'] += torch.sum((preds == 1) * (labels == 1)).cpu().item() 169 | stats[threshold]['TN'] += torch.sum((preds == 0) * (labels == 0)).cpu().item() 170 | stats[threshold]['FP'] += torch.sum((preds == 1) * (labels == 0)).cpu().item() 171 | stats[threshold]['FN'] += torch.sum((preds == 0) * (labels == 1)).cpu().item() 172 | 173 | # loss accumulation 174 | running_loss += loss.item() * inputs.size(0) 175 | 176 | training_log['current_epoch'] = epoch 177 | epoch_loss = running_loss / len(dataloaders[phase].dataset) 178 | 179 | if phase == 'train': 180 | training_log['train_loss_history'].append(epoch_loss) 181 | if scheduler is not None: 182 | scheduler.step() 183 | if verbose: 184 | print('{} Loss: {:.4f}'.format(phase, epoch_loss)) 185 | 186 | if phase == 'val': 187 | epoch_best_threshold = 0.0 188 | epoch_max_metrics = 0.0 189 | for threshold in threshold_list: 190 | metric_value = metrics(stats[threshold]) 191 | if metric_value > epoch_max_metrics: 192 | epoch_best_threshold = threshold 193 | epoch_max_metrics = metric_value 194 | spec = (stats[epoch_best_threshold]['TN'] + 0.00001) * 1.0 / (stats[epoch_best_threshold]['TN'] + stats[epoch_best_threshold]['FP'] + 0.00001) 195 | sens = (stats[epoch_best_threshold]['TP'] + 0.00001) * 1.0 / (stats[epoch_best_threshold]['TP'] + stats[epoch_best_threshold]['FN'] + 0.00001) 196 | 197 | if verbose: 198 | print('{} Loss: {:.4f} Metrics: {:.4f} Threshold: {:.4f} Sensitivity: {:.4f} Specificity {:4f}'.format(phase, epoch_loss, 199 | epoch_max_metrics, epoch_best_threshold, sens, spec)) 200 | 201 | training_log['val_metric_value_history'].append(epoch_max_metrics) 202 | training_log['val_loss_history'].append(epoch_loss) 203 | training_log['epoch_best_threshold_history'].append(epoch_best_threshold) 204 | 205 | # deep copy the model 206 | if epoch_max_metrics > best_metric_value: 207 | best_metric_value = epoch_max_metrics 208 | best_threshold = epoch_best_threshold 209 | best_model_wts = copy.deepcopy(model.state_dict()) 210 | best_optimizer_wts = copy.deepcopy(optimizer.state_dict()) 211 | best_log = copy.deepcopy(training_log) 212 | nodecrease = 0 213 | else: 214 | nodecrease += 1 215 | 216 | if nodecrease >= early_stop_epochs: 217 | early_stop = True 218 | 219 | if save_dir and epoch % save_epochs == 0: 220 | checkpoint = { 221 | 'model_state_dict': model.state_dict(), 222 | 'optimizer_state_dict': optimizer.state_dict(), 223 | 'training_log': training_log 224 | } 225 | torch.save(checkpoint, 226 | os.path.join(save_dir, model_name + '_' + str(training_log['current_epoch']) + '.tar')) 227 | 228 | if if_early_stop and early_stop: 229 | 
print('Early stopped!') 230 | break 231 | 232 | time_elapsed = time.time() - since 233 | print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 234 | print('Best validation metric value: {:4f}'.format(best_metric_value)) 235 | print('Best validation threshold: {:4f}'.format(best_threshold)) 236 | 237 | # load best model weights 238 | if return_best: 239 | model.load_state_dict(best_model_wts) 240 | optimizer.load_state_dict(best_optimizer_wts) 241 | training_log = best_log 242 | 243 | checkpoint = { 244 | 'model_state_dict': model.state_dict(), 245 | 'optimizer_state_dict': optimizer.state_dict(), 246 | 'training_log': training_log 247 | } 248 | torch.save(checkpoint, 249 | os.path.join(save_dir, model_name + '_' + str(training_log['current_epoch']) + '_last.tar')) 250 | 251 | return model, training_log, best_metric_value, best_threshold 252 | 253 | 254 | data_transforms = { 255 | 'train': transforms.Compose([ 256 | transforms.Resize((input_size, input_size)), 257 | transforms.Lambda(RandomRotationNew), 258 | transforms.RandomHorizontalFlip(p=0.5), 259 | transforms.RandomVerticalFlip(p=0.5), 260 | transforms.ToTensor(), 261 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 262 | ]), 263 | 'val': transforms.Compose([ 264 | transforms.Resize((input_size, input_size)), 265 | transforms.ToTensor(), 266 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 267 | ]) 268 | } 269 | 270 | 271 | if __name__ == '__main__': 272 | # data 273 | image_datasets = {x: BinaryImageFolder(data_dirs_dict[x], data_transforms[x]) for x in ['train', 'val']} 274 | dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, 275 | shuffle=True, num_workers=4) for x in ['train', 'val']} 276 | 277 | results_dict = {x: {y: {} for y in lr_list} for x in lr_decay_epochs_list} 278 | 279 | if not os.path.exists(ckpt_save_dir): 280 | os.mkdir(ckpt_save_dir) 281 | 282 | # model 283 | for lr_decay_epochs in lr_decay_epochs_list: 284 | for learning_rate in lr_list: 285 | print('----------------------- ' + str(lr_decay_epochs) + ', ' + str(learning_rate) + ' -----------------------') 286 | model = Inception3(num_classes=2, aux_logits=True, transform_input=False) 287 | model = model.to(device) 288 | optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, 289 | weight_decay=weight_decay, amsgrad=True) 290 | class_weight = torch.tensor([1, imbalance_rate], dtype=torch.float).cuda() 291 | loss_fn = nn.CrossEntropyLoss(weight=class_weight) 292 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_epochs, gamma=lr_decay_rate) 293 | 294 | # load old parameters 295 | if old_ckpt_path: 296 | checkpoint = torch.load(old_ckpt_path, map_location=device) 297 | if old_ckpt_path[-4:] == '.tar': # it is a checkpoint dictionary rather than just model parameters 298 | model.load_state_dict(checkpoint['model_state_dict']) 299 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 300 | training_log = checkpoint['training_log'] 301 | else: 302 | model.load_state_dict(checkpoint, strict=False) 303 | training_log = None 304 | print('Old checkpoint loaded: ' + old_ckpt_path) 305 | else: 306 | training_log = None # start from scratch 307 | 308 | # fix some layers and make others trainable 309 | if trainable_params: 310 | only_train(model, trainable_params) 311 | 312 | _, _, best_metric_value, best_threshold = train_model(model, model_name=model_name+'_decay_'+str(lr_decay_epochs)+'_lr_'+str(learning_rate), 313 | 
dataloaders=dataloaders_dict, criterion=loss_fn, 314 | optimizer=optimizer, metrics=metrics, num_epochs=num_epochs, 315 | training_log=training_log, verbose=True, return_best=return_best, 316 | if_early_stop=if_early_stop, early_stop_epochs=early_stop_epochs, 317 | scheduler=scheduler, save_dir=ckpt_save_dir, save_epochs=save_epochs) 318 | results_dict[lr_decay_epochs][learning_rate] = {'metrics': best_metric_value, 'threshold': best_threshold} 319 | 320 | with open(join(ckpt_save_dir, 'results_dict.pickle'), 'wb') as f: 321 | pickle.dump(results_dict, f) 322 | 323 | for lr_decay_epochs in lr_decay_epochs_list: 324 | for learning_rate in lr_list: 325 | print('lr_decay_epochs: ', lr_decay_epochs, 'learning_rate: ', learning_rate, 326 | 'metric: ', results_dict[lr_decay_epochs][learning_rate]['metrics'], 327 | 'threshold: ', results_dict[lr_decay_epochs][learning_rate]['threshold']) 328 | -------------------------------------------------------------------------------- /hp_search_LR_rgb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils.data import Dataset, DataLoader 7 | import torchvision 8 | from torchvision import datasets, models, transforms, utils 9 | import torchvision.transforms.functional as TF 10 | 11 | from tqdm import tqdm 12 | import numpy as np 13 | import pandas as pd 14 | import pickle 15 | import matplotlib.pyplot as plt 16 | # import skimage 17 | # import skimage.io 18 | # import skimage.transform 19 | from PIL import Image 20 | import time 21 | import os 22 | from os.path import join, exists 23 | import copy 24 | import random 25 | from collections import OrderedDict 26 | from sklearn.metrics import r2_score 27 | 28 | from torch.nn import functional as F 29 | from torchvision.models import Inception3 30 | 31 | from utils.image_dataset import * 32 | from LR_models.siamese_model_rgb import * 33 | 34 | """ 35 | This script is for training solar panel classification model for low-resolution (LR) 36 | images. The model is a pseudo-Siamese network that takes one target image and one 37 | reference image as inputs and predicts whether the reference image contains solar 38 | or not. The hyperparameters to search include the depth of convolutional layers, 39 | number of convolutional layers, and number of filters. 40 | """ 41 | 42 | # Configuration 43 | # directory for loading training/validation/test data 44 | data_dirs_dict = { 45 | 'train': ['data/LR_images/train'], 46 | 'val': ['data/LR_images/val'], 47 | 'test': ['data/LR_images/test'], 48 | } 49 | 50 | # each pickle file is a dict mapping a relative path of target image to 51 | # a list of relative paths of its corresponding reference images. 52 | # The length of the path list under each subset should match the length 53 | # of each subset of data_dirs_dict. E.g., len(data_dirs_dict['train']) 54 | # must be equal to len(reference_mapping_paths_dict['train']) 55 | reference_mapping_paths_dict = { 56 | 'train': ['data/LR_images/train/reference_mapping_train.pickle'], 57 | 'val': ['data/LR_images/val/reference_mapping_val.pickle'], 58 | 'test': ['data/LR_images/test/reference_mapping_test.pickle'] 59 | } 60 | 61 | # paths to load old model/checkpoint. 
62 | old_ckpt_path_dict = { 63 | 'resnet34': 'checkpoint/resnet34-333f7ec4.pth', 64 | 'resnet50': 'checkpoint/resnet50-19c8e357.pth', 65 | } 66 | # directory for saving model/checkpoint 67 | ckpt_save_dir = 'checkpoint/LR_new_model' 68 | 69 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 70 | backbone = 'resnet34' 71 | old_ckpt_path = old_ckpt_path_dict[backbone] # set it to be None if not loading any. 72 | trainable_params = None # layers or modules set to be trainable. "None" if training all layers 73 | model_name = 'LR' # the prefix of the filename for saving model/checkpoint 74 | return_best = True # whether to return the best model according to the validation metrics 75 | if_early_stop = True # whether to stop early after validation metrics doesn't improve for definite number of epochs 76 | input_size = 299 # image size fed into the model 77 | imbalance_rate = 1.0 # weight given to the positive (rarer) samples in loss function 78 | learning_rate = 0.0001 # learning rate 79 | weight_decay = 0 # l2 regularization coefficient 80 | batch_size = 64 81 | num_epochs = 100 # nls 82 | # umber of epochs to train 83 | lr_decay_rate = 0.95 # learning rate decay rate for each decay step 84 | lr_decay_epochs = 10 # number of epochs for one learning rate decay 85 | early_stop_epochs = 10 # after validation metrics doesn't improve for "early_stop_epochs" epochs, stop the training. 86 | save_epochs = 50 # save the model/checkpoint every "save_epochs" epochs 87 | # threshold = 0.2 # threshold probability to identify am image as positive 88 | # nfilters = 256 89 | # lr_list = [0.0001] 90 | # lr_decay_epochs_list = [10, 4] 91 | 92 | # The hyperparameters to search 93 | depth_list = [128] 94 | nconvs_list = [3, 2] 95 | nfilters_list = [512, 384, 256] 96 | # weight_decay_list = [0] 97 | threshold_list = np.linspace(0.0, 1.0, 101).tolist() 98 | 99 | 100 | def RandomRotationNew(image): 101 | angle = random.choice([0, 90, 180, 270]) 102 | image = TF.rotate(image, angle) 103 | return image 104 | 105 | 106 | def mask_image_info(img): 107 | img = np.array(img) 108 | img[0:18, 0:95] = 0 109 | img[289:298, 0:299] = 0 110 | # img[256:263, 122:202] = 0 111 | img = Image.fromarray(img) 112 | return img 113 | 114 | 115 | def only_train(model, trainable_params): 116 | """trainable_params: The list of parameters and modules that are set to be trainable. 
117 | Set require_grad = False for all those parameters not in the trainable_params""" 118 | print('Only the following layers:') 119 | for name, p in model.named_parameters(): 120 | p.requires_grad = False 121 | for target in trainable_params: 122 | if target == name or target in name: 123 | p.requires_grad = True 124 | print(' ' + name) 125 | break 126 | 127 | 128 | def metrics(stats): 129 | """stats: {'TP': TP, 'FP': FP, 'TN': TN, 'FN': FN} 130 | return: must be a single number """ 131 | # precision = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FP'] + 0.00001) 132 | # recall = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FN'] + 0.00001) 133 | # F1 = 2.0 * stats['TP'] / (2 * stats['TP'] + stats['FP'] + stats['FN']) 134 | spec = (stats['TN'] + 0.00001) * 1.0 / (stats['TN'] + stats['FP'] + 0.00001) 135 | sens = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FN'] + 0.00001) 136 | return 2.0 * spec * sens / (spec + sens + 1e-7) 137 | 138 | 139 | def train_model(model, model_name, dataloaders, criterion, optimizer, metrics, num_epochs, training_log=None, 140 | verbose=True, return_best=True, if_early_stop=True, early_stop_epochs=10, scheduler=None, 141 | save_dir=None, save_epochs=5): 142 | since = time.time() 143 | if not training_log: 144 | training_log = dict() 145 | training_log['train_loss_history'] = [] 146 | training_log['val_loss_history'] = [] 147 | training_log['val_metric_value_history'] = [] 148 | training_log['epoch_best_threshold_history'] = [] 149 | training_log['current_epoch'] = -1 150 | current_epoch = training_log['current_epoch'] + 1 151 | 152 | best_model_wts = copy.deepcopy(model.state_dict()) 153 | best_optimizer_wts = copy.deepcopy(optimizer.state_dict()) 154 | best_log = copy.deepcopy(training_log) 155 | 156 | best_metric_value = -np.inf 157 | best_threshold = 0 158 | nodecrease = 0 # to count the epochs that val loss doesn't decrease 159 | early_stop = False 160 | 161 | for epoch in range(current_epoch, current_epoch + num_epochs): 162 | if verbose: 163 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 164 | print('-' * 10) 165 | 166 | # Each epoch has a training and validation phase 167 | for phase in ['train', 'val']: 168 | if phase == 'train': 169 | model.train() # Set model to training mode 170 | else: 171 | model.eval() # Set model to evaluate mode 172 | 173 | running_loss = 0.0 174 | stats = {x: {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0} for x in threshold_list} 175 | 176 | # Iterate over data. 
177 | for inputs_ref, inputs_tar, labels in tqdm(dataloaders[phase]): 178 | inputs_ref = inputs_ref.to(device) 179 | inputs_tar = inputs_tar.to(device) 180 | labels = labels.to(device) 181 | 182 | # zero the parameter gradients 183 | optimizer.zero_grad() 184 | 185 | # forward 186 | # track history if only in train 187 | with torch.set_grad_enabled(phase == 'train'): 188 | # Get model outputs and calculate loss 189 | if phase == 'train': 190 | outputs = model(inputs_tar, inputs_ref) 191 | loss = criterion(outputs, labels) 192 | # backward + optimize only if in training phase 193 | loss.backward() 194 | optimizer.step() 195 | 196 | else: 197 | outputs = model(inputs_tar, inputs_ref) 198 | loss = criterion(outputs, labels) 199 | # val phase: calculate metrics under different threshold 200 | prob = F.softmax(outputs, dim=1) 201 | for threshold in threshold_list: 202 | preds = prob[:, 1] >= threshold 203 | stats[threshold]['TP'] += torch.sum((preds == 1) * (labels == 1)).cpu().item() 204 | stats[threshold]['TN'] += torch.sum((preds == 0) * (labels == 0)).cpu().item() 205 | stats[threshold]['FP'] += torch.sum((preds == 1) * (labels == 0)).cpu().item() 206 | stats[threshold]['FN'] += torch.sum((preds == 0) * (labels == 1)).cpu().item() 207 | 208 | # loss accumulation 209 | running_loss += loss.item() * inputs_ref.size(0) 210 | 211 | training_log['current_epoch'] = epoch 212 | epoch_loss = running_loss / len(dataloaders[phase].dataset) 213 | 214 | if phase == 'train': 215 | training_log['train_loss_history'].append(epoch_loss) 216 | if scheduler is not None: 217 | scheduler.step() 218 | if verbose: 219 | print('{} Loss: {:.4f}'.format(phase, epoch_loss)) 220 | 221 | if phase == 'val': 222 | epoch_best_threshold = 0.0 223 | epoch_max_metrics = 0.0 224 | for threshold in threshold_list: 225 | metric_value = metrics(stats[threshold]) 226 | if metric_value > epoch_max_metrics: 227 | epoch_best_threshold = threshold 228 | epoch_max_metrics = metric_value 229 | spec = (stats[epoch_best_threshold]['TN'] + 0.00001) * 1.0 / (stats[epoch_best_threshold]['TN'] + stats[epoch_best_threshold]['FP'] + 0.00001) 230 | sens = (stats[epoch_best_threshold]['TP'] + 0.00001) * 1.0 / (stats[epoch_best_threshold]['TP'] + stats[epoch_best_threshold]['FN'] + 0.00001) 231 | 232 | if verbose: 233 | print('{} Loss: {:.4f} Metrics: {:.4f} Threshold: {:.4f} Sensitivity: {:.4f} Specificity {:4f}'.format(phase, epoch_loss, 234 | epoch_max_metrics, epoch_best_threshold, sens, spec)) 235 | 236 | training_log['val_metric_value_history'].append(epoch_max_metrics) 237 | training_log['val_loss_history'].append(epoch_loss) 238 | training_log['epoch_best_threshold_history'].append(epoch_best_threshold) 239 | 240 | # deep copy the model 241 | if epoch_max_metrics > best_metric_value: 242 | best_metric_value = epoch_max_metrics 243 | best_threshold = epoch_best_threshold 244 | best_model_wts = copy.deepcopy(model.state_dict()) 245 | best_optimizer_wts = copy.deepcopy(optimizer.state_dict()) 246 | best_log = copy.deepcopy(training_log) 247 | nodecrease = 0 248 | else: 249 | nodecrease += 1 250 | 251 | if nodecrease >= early_stop_epochs: 252 | early_stop = True 253 | 254 | if save_dir and epoch % save_epochs == 0 and epoch > 0: 255 | checkpoint = { 256 | 'model_state_dict': model.state_dict(), 257 | 'optimizer_state_dict': optimizer.state_dict(), 258 | 'training_log': training_log 259 | } 260 | torch.save(checkpoint, 261 | os.path.join(save_dir, model_name + '_' + str(training_log['current_epoch']) + '.tar')) 262 | 263 | if if_early_stop 
and early_stop: 264 | print('Early stopped!') 265 | break 266 | 267 | time_elapsed = time.time() - since 268 | print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 269 | print('Best validation metric value: {:4f}'.format(best_metric_value)) 270 | print('Best validation threshold: {:4f}'.format(best_threshold)) 271 | 272 | # load best model weights 273 | if return_best: 274 | model.load_state_dict(best_model_wts) 275 | optimizer.load_state_dict(best_optimizer_wts) 276 | training_log = best_log 277 | 278 | checkpoint = { 279 | 'model_state_dict': model.state_dict(), 280 | 'optimizer_state_dict': optimizer.state_dict(), 281 | 'training_log': training_log 282 | } 283 | torch.save(checkpoint, 284 | os.path.join(save_dir, model_name + '_' + str(training_log['current_epoch']) + '_last.tar')) 285 | 286 | return model, training_log, best_metric_value, best_threshold 287 | 288 | 289 | def test_model(model, dataloader, metrics, threshold_list): 290 | stats = {x: {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0} for x in threshold_list} 291 | metric_values = {} 292 | model.eval() 293 | for inputs_ref, inputs_tar, labels in tqdm(dataloader): 294 | inputs_ref = inputs_ref.to(device) 295 | inputs_tar = inputs_tar.to(device) 296 | labels = labels.to(device) 297 | 298 | with torch.set_grad_enabled(False): 299 | outputs = model(inputs_tar, inputs_ref) 300 | prob = F.softmax(outputs, dim=1) 301 | for threshold in threshold_list: 302 | preds = prob[:, 1] >= threshold 303 | stats[threshold]['TP'] += torch.sum((preds == 1) * (labels == 1)).cpu().item() 304 | stats[threshold]['TN'] += torch.sum((preds == 0) * (labels == 0)).cpu().item() 305 | stats[threshold]['FP'] += torch.sum((preds == 1) * (labels == 0)).cpu().item() 306 | stats[threshold]['FN'] += torch.sum((preds == 0) * (labels == 1)).cpu().item() 307 | 308 | for threshold in threshold_list: 309 | metric_values[threshold] = metrics(stats[threshold]) 310 | 311 | return stats, metric_values 312 | 313 | 314 | data_transforms = { 315 | 'train': transforms.Compose([ 316 | transforms.Resize((input_size, input_size)), # transfer to input size 317 | # transforms.Lambda(mask_image_info), 318 | transforms.ToTensor() 319 | ]), 320 | 'val': transforms.Compose([ 321 | transforms.Resize((input_size, input_size)), 322 | # transforms.Lambda(mask_image_info), 323 | transforms.ToTensor() 324 | ]), 325 | 'test': transforms.Compose([ 326 | transforms.Resize((input_size, input_size)), 327 | # transforms.Lambda(mask_image_info), 328 | transforms.ToTensor() 329 | ]) 330 | } 331 | 332 | 333 | if __name__ == '__main__': 334 | # data 335 | image_datasets = {x: ImagePairDataset(data_dirs_dict[x], 336 | reference_mapping_paths_dict[x], 337 | is_train=(x == 'train'), 338 | binary=False, 339 | transform=data_transforms[x]) for x in ['train', 'val', 'test']} 340 | dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, 341 | shuffle=True, num_workers=4) for x in ['train', 'val', 'test']} 342 | 343 | print('Training set size: ' + str(len(image_datasets['train']))) 344 | print('Validation set size: ' + str(len(image_datasets['val']))) 345 | print('Test set size: ' + str(len(image_datasets['test']))) 346 | 347 | results_dict = {x: {y: {z: {} for z in nfilters_list} for y in depth_list} for x in nconvs_list} 348 | 349 | if not os.path.exists(ckpt_save_dir): 350 | os.mkdir(ckpt_save_dir) 351 | 352 | # model 353 | for nconvs in nconvs_list: 354 | for depth in depth_list: 355 | for nfilters in nfilters_list: 356 | 
print('----------------------- ' + 357 | str(nconvs) + ', ' + 358 | str(depth) + ', ' + 359 | str(nfilters) + 360 | ' -----------------------') 361 | model = psn_depthwise_cc_layerwise_3layers_l234(backbone=backbone, nconvs=nconvs, depth=depth, nfilters=nfilters, 362 | kernel_size=3) 363 | optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, 364 | weight_decay=weight_decay, amsgrad=True) 365 | class_weight = torch.tensor([1, imbalance_rate], dtype=torch.float).cuda() 366 | loss_fn = nn.CrossEntropyLoss(weight=class_weight) 367 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_epochs, gamma=lr_decay_rate) 368 | 369 | # load old parameters 370 | if old_ckpt_path: 371 | checkpoint = torch.load(old_ckpt_path) 372 | if old_ckpt_path[-4:] == '.tar': # it is a checkpoint dictionary rather than just model parameters 373 | model.load_state_dict(checkpoint['model_state_dict']) 374 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 375 | training_log = checkpoint['training_log'] 376 | 377 | else: 378 | del checkpoint['fc.weight'] 379 | del checkpoint['fc.bias'] 380 | model.net1.load_state_dict(checkpoint, strict=False) 381 | model.net2.load_state_dict(checkpoint, strict=False) 382 | training_log = None # start from scratch 383 | 384 | print('Old checkpoint loaded: ' + old_ckpt_path) 385 | 386 | model = model.to(device) 387 | 388 | # fix some layers and make others trainable 389 | if trainable_params: 390 | only_train(model, trainable_params) 391 | 392 | best_model, _, best_metric_value, best_threshold = train_model(model, model_name=model_name + '_nconvs_' + str( 393 | nconvs) + '_depth_' + str(depth) + '_nfilters_' + str(nfilters), 394 | dataloaders=dataloaders_dict, criterion=loss_fn, 395 | optimizer=optimizer, metrics=metrics, 396 | num_epochs=num_epochs, 397 | training_log=training_log, verbose=True, 398 | return_best=return_best, 399 | if_early_stop=if_early_stop, 400 | early_stop_epochs=early_stop_epochs, 401 | scheduler=scheduler, save_dir=ckpt_save_dir, 402 | save_epochs=save_epochs) 403 | 404 | print('Begin test ...') 405 | test_stats, test_metric_values = test_model(best_model, dataloaders_dict['test'], metrics, threshold_list=threshold_list) 406 | 407 | best_threshold_test, best_metric_value_test = sorted(list(test_metric_values.items()), key=lambda tup: tup[1], reverse=True)[0] 408 | 409 | results_dict[nconvs][depth][nfilters] = {'best_metrics_val': best_metric_value, 410 | 'best_threshold_val': best_threshold, 411 | 'best_metric_test': best_metric_value_test, 412 | 'best_threshold_test': best_threshold_test, 413 | 'test_metrics_with_val_best_threshold': 414 | test_metric_values[best_threshold], 415 | } 416 | 417 | with open(join(ckpt_save_dir, 'results_dict.pickle'), 'wb') as f: 418 | pickle.dump(results_dict, f) 419 | 420 | print(results_dict) 421 | -------------------------------------------------------------------------------- /hp_search_ood_multilabels.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils.data import Dataset, DataLoader 7 | import torchvision 8 | from torchvision import datasets, models, transforms, utils 9 | import torchvision.transforms.functional as TF 10 | 11 | from tqdm import tqdm 12 | import numpy as np 13 | import pandas as pd 14 | import pickle 15 | import matplotlib.pyplot as plt 16 | 
# import skimage 17 | # import skimage.io 18 | # import skimage.transform 19 | from PIL import Image 20 | import time 21 | import os 22 | from os.path import join, exists 23 | import copy 24 | import random 25 | from collections import OrderedDict 26 | from sklearn.metrics import r2_score 27 | 28 | from torch.nn import functional as F 29 | from torchvision.models import Inception3, resnet18, resnet34, resnet50 30 | 31 | from utils.image_dataset import * 32 | from LR_models.siamese_model_rgb import * 33 | 34 | """ 35 | This script is for training blur detection model that classifies an image 36 | into three classes: OOD (out of distribution, extremely blurred), HR (high 37 | resolution), LR (low resolution). It is a single-branch CNN based on 38 | ResNet-50 model. The hyperparameters to search include learning rate, 39 | learning rate decay epochs, and weight decay. 40 | """ 41 | 42 | # Configuration 43 | # directory for loading training/validation/test data 44 | # for each of "train"/"val"/"test", put the image folders of class "OOD" to the first 45 | # list, put the image folders of class "LR" to the second list, and put image folders 46 | # of class "HR" to the third list. 47 | 48 | dirs_list_dict = { 49 | 'train': 50 | [[ 51 | 'data/blur_detection_images/train/OOD', 52 | ], 53 | [ 54 | 'data/blur_detection_images/train/LR', 55 | ], 56 | [ 57 | 'data/blur_detection_images/train/HR', 58 | ]], 59 | 'val': 60 | [[ 61 | 'data/blur_detection_images/val/OOD', 62 | ], 63 | [ 64 | 'data/blur_detection_images/val/LR', 65 | ], 66 | [ 67 | 'data/blur_detection_images/val/HR', 68 | ]], 69 | 'test': 70 | [[ 71 | 'data/blur_detection_images/test/OOD', 72 | ], 73 | [ 74 | 'data/blur_detection_images/test/LR', 75 | ], 76 | [ 77 | 'data/blur_detection_images/test/HR', 78 | ]], 79 | } 80 | 81 | old_ckpt_path_dict = { 82 | 'resnet34': 'checkpoint/resnet34-333f7ec4.pth', 83 | 'resnet50': 'checkpoint/resnet50-19c8e357.pth', 84 | } 85 | # directory for saving model/checkpoint 86 | ckpt_save_dir = 'checkpoint/ood_new_model' 87 | 88 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 89 | model_arch = 'resnet50' 90 | nclasses = 2 91 | old_ckpt_path = old_ckpt_path_dict[model_arch] # path to load old model/checkpoint, set it to None if not loading 92 | trainable_params = None # layers or modules set to be trainable. "None" if training all layers 93 | model_name = 'ood' # the prefix of the filename for saving model/checkpoint 94 | return_best = True # whether to return the best model according to the validation metrics 95 | if_early_stop = True # whether to stop early after validation metrics doesn't improve for definite number of epochs 96 | input_size = 299 # image size fed into the model 97 | imbalance_rate = 1.0 # weight given to the positive (rarer) samples in loss function 98 | learning_rate = 0.0001 # learning rate 99 | # weight_decay = 0 # l2 regularization coefficient 100 | batch_size = 64 101 | num_epochs = 100 # number of epochs to train 102 | lr_decay_rate = 0.95 # learning rate decay rate for each decay step 103 | # lr_decay_epochs = 10 # number of epochs for one learning rate decay 104 | early_stop_epochs = 10 # after validation metrics doesn't improve for "early_stop_epochs" epochs, stop the training. 
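# [Editor's note] Although the task has three classes (OOD / LR / HR), each image
# carries a 2-dimensional multi-label target and the network has nclasses = 2
# output logits trained with BCEWithLogitsLoss: FolderDirsDatasetMultiLabels in
# utils/image_dataset.py encodes OOD as [0, 0], LR as [1, 0], and HR as [1, 1],
# so logit 0 predicts "in distribution (not OOD)" and logit 1 predicts "is HR".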
105 | save_epochs = 50 # save the model/checkpoint every "save_epochs" epochs 106 | # threshold = 0.2 # threshold probability to identify am image as positive 107 | ib1 = 1 # weight for imbalance class 108 | 109 | # hyperparamters to tune 110 | lr_list = [0.00001, 0.0001, 0.001] # learning rates 111 | lr_decay_epochs_list = [10, 4] # learning rate decay epochs 112 | weight_decay_list = [0, 0.001] # weight decay 113 | threshold_list = np.linspace(0.0, 1.0, 101).tolist() 114 | 115 | 116 | def RandomRotationNew(image): 117 | angle = random.choice([0, 180]) 118 | image = TF.rotate(image, angle) 119 | return image 120 | 121 | 122 | class MyCrop: 123 | def __init__(self, top, left, height, width): 124 | self.top = top 125 | self.left = left 126 | self.height = height 127 | self.width = width 128 | 129 | def __call__(self, img): 130 | return TF.crop(img, self.top, self.left, self.height, self.width) 131 | 132 | 133 | def only_train(model, trainable_params): 134 | """trainable_params: The list of parameters and modules that are set to be trainable. 135 | Set require_grad = False for all those parameters not in the trainable_params""" 136 | print('Only the following layers:') 137 | for name, p in model.named_parameters(): 138 | p.requires_grad = False 139 | for target in trainable_params: 140 | if target == name or target in name: 141 | p.requires_grad = True 142 | print(' ' + name) 143 | break 144 | 145 | 146 | def metrics(stats): 147 | """stats: {'TP': TP, 'FP': FP, 'TN': TN, 'FN': FN} 148 | return: must be a single number """ 149 | precision = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FP'] + 0.00001) 150 | recall = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FN'] + 0.00001) 151 | spec = (stats['TN'] + 0.00001) * 1.0 / (stats['TN'] + stats['FP'] + 0.00001) 152 | sens = (stats['TP'] + 0.00001) * 1.0 / (stats['TP'] + stats['FN'] + 0.00001) 153 | hm1 = 2.0 * precision * recall / (precision + recall + 1e-7) 154 | hm2 = 2.0 * spec * sens / (spec + sens + 1e-7) 155 | return hm1, hm2 156 | 157 | 158 | def train_model(model, model_name, dataloaders, criterion, optimizer, metrics, num_epochs, training_log=None, 159 | verbose=True, return_best=True, if_early_stop=True, early_stop_epochs=10, scheduler=None, 160 | save_dir=None, save_epochs=5): 161 | since = time.time() 162 | if not training_log: 163 | training_log = dict() 164 | training_log['train_loss_history'] = [] 165 | training_log['val_loss_history'] = [] 166 | training_log['val_metric_value_history'] = [] 167 | training_log['epoch_best_threshold_history'] = [] 168 | training_log['current_epoch'] = -1 169 | current_epoch = training_log['current_epoch'] + 1 170 | 171 | best_model_wts = copy.deepcopy(model.state_dict()) 172 | best_optimizer_wts = copy.deepcopy(optimizer.state_dict()) 173 | best_log = copy.deepcopy(training_log) 174 | 175 | best_metric_value = -np.inf 176 | best_threshold_1 = 0 177 | best_threshold_2 = 0 178 | nodecrease = 0 # to count the epochs that val loss doesn't decrease 179 | early_stop = False 180 | 181 | for epoch in range(current_epoch, current_epoch + num_epochs): 182 | if verbose: 183 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 184 | print('-' * 10) 185 | 186 | # Each epoch has a training and validation phase 187 | for phase in ['train', 'val']: 188 | if phase == 'train': 189 | model.train() # Set model to training mode 190 | else: 191 | model.eval() # Set model to evaluate mode 192 | 193 | running_loss = 0.0 194 | 195 | stats1 = {x: {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0} for x in threshold_list} 
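            # [Editor's note] stats1 (above) and stats2 (below) accumulate per-threshold
            # confusion counts for the two output dimensions: dim 0 ("not OOD") and dim 1
            # ("is HR"). metrics() scores dim 0 with the precision/recall harmonic mean
            # (hm1) and dim 1 with the specificity/sensitivity harmonic mean (hm2).
            # Hypothetical worked example: TP=8, FP=2, TN=85, FN=5 gives precision = 0.80,
            # recall ~= 0.615, hence hm1 ~= 0.696; spec ~= 0.977, sens ~= 0.615, hence
            # hm2 ~= 0.755 (the 1e-5 and 1e-7 smoothing terms are negligible here).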
196 | stats2 = {x: {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0} for x in threshold_list} 197 | 198 | # Iterate over data. 199 | for inputs, labels in tqdm(dataloaders[phase]): 200 | inputs = inputs.to(device) 201 | labels = labels.to(device) 202 | 203 | # zero the parameter gradients 204 | optimizer.zero_grad() 205 | 206 | # forward 207 | # track history if only in train 208 | with torch.set_grad_enabled(phase == 'train'): 209 | # Get model outputs and calculate loss 210 | if phase == 'train': 211 | if model_arch != 'inception': 212 | outputs = model(inputs) 213 | loss = criterion(outputs, labels) 214 | else: 215 | outputs, aux_outputs = model(inputs) 216 | loss1 = criterion(outputs, labels) 217 | loss2 = criterion(aux_outputs, labels) 218 | loss = loss1 + 0.4 * loss2 219 | 220 | # backward + optimize only if in training phase 221 | loss.backward() 222 | optimizer.step() 223 | 224 | else: 225 | outputs = model(inputs) 226 | loss = criterion(outputs, labels) 227 | # val phase: calculate metrics under different threshold 228 | prob = torch.sigmoid(outputs) 229 | 230 | labels1 = labels[:, 0] 231 | labels2 = labels[:, 1] 232 | 233 | for threshold1 in threshold_list: 234 | preds1 = prob[:, 0] >= threshold1 235 | stats1[threshold1]['TP'] += torch.sum((preds1 == 1) * (labels1 == 1)).cpu().item() 236 | stats1[threshold1]['TN'] += torch.sum((preds1 == 0) * (labels1 == 0)).cpu().item() 237 | stats1[threshold1]['FP'] += torch.sum((preds1 == 1) * (labels1 == 0)).cpu().item() 238 | stats1[threshold1]['FN'] += torch.sum((preds1 == 0) * (labels1 == 1)).cpu().item() 239 | 240 | for threshold2 in threshold_list: 241 | preds2 = prob[:, 1] >= threshold2 242 | stats2[threshold2]['TP'] += torch.sum((preds2 == 1) * (labels2 == 1)).cpu().item() 243 | stats2[threshold2]['TN'] += torch.sum((preds2 == 0) * (labels2 == 0)).cpu().item() 244 | stats2[threshold2]['FP'] += torch.sum((preds2 == 1) * (labels2 == 0)).cpu().item() 245 | stats2[threshold2]['FN'] += torch.sum((preds2 == 0) * (labels2 == 1)).cpu().item() 246 | 247 | # loss accumulation 248 | running_loss += loss.item() * inputs.size(0) 249 | 250 | training_log['current_epoch'] = epoch 251 | epoch_loss = running_loss / len(dataloaders[phase].dataset) 252 | 253 | if phase == 'train': 254 | training_log['train_loss_history'].append(epoch_loss) 255 | if scheduler is not None: 256 | scheduler.step() 257 | if verbose: 258 | print('{} Loss: {:.4f}'.format(phase, epoch_loss)) 259 | 260 | if phase == 'val': 261 | epoch_best_threshold_1 = 0.0 262 | epoch_best_threshold_2 = 0.0 263 | epoch_max_metrics_1 = 0.0 264 | epoch_max_metrics_2 = 0.0 265 | for threshold1 in threshold_list: 266 | metric_value_1, _ = metrics(stats1[threshold1]) 267 | if metric_value_1 > epoch_max_metrics_1: 268 | epoch_best_threshold_1 = threshold1 269 | epoch_max_metrics_1 = metric_value_1 270 | 271 | for threshold2 in threshold_list: 272 | _, metric_value_2 = metrics(stats2[threshold2]) 273 | if metric_value_2 > epoch_max_metrics_2: 274 | epoch_best_threshold_2 = threshold2 275 | epoch_max_metrics_2 = metric_value_2 276 | 277 | epoch_max_metrics = 2.0 * epoch_max_metrics_1 * (epoch_max_metrics_2**2) / (epoch_max_metrics_1 + (epoch_max_metrics_2**2)) 278 | 279 | recall = (stats1[epoch_best_threshold_1]['TP'] + 0.00001) * 1.0 / ( 280 | stats1[epoch_best_threshold_1]['TP'] + stats1[epoch_best_threshold_1]['FN'] + 0.00001) 281 | precision = (stats1[epoch_best_threshold_1]['TP'] + 0.00001) * 1.0 / ( 282 | stats1[epoch_best_threshold_1]['TP'] + stats1[epoch_best_threshold_1]['FP'] + 0.00001) 283 | 284 | spec = 
(stats2[epoch_best_threshold_2]['TN'] + 0.00001) * 1.0 / ( 285 | stats2[epoch_best_threshold_2]['TN'] + stats2[epoch_best_threshold_2]['FP'] + 0.00001) 286 | sens = (stats2[epoch_best_threshold_2]['TP'] + 0.00001) * 1.0 / ( 287 | stats2[epoch_best_threshold_2]['TP'] + stats2[epoch_best_threshold_2]['FN'] + 0.00001) 288 | 289 | if verbose: 290 | print('{} Loss: {:.4f} Metrics: {:.4f} Threshold1: {:.4f} Threshold2: {:.4f} Recall: {:.4f} Precision: {:.4f} Sensitivity: {:.4f} Specificity: {:.4f}'.format(phase, epoch_loss, 291 | epoch_max_metrics, epoch_best_threshold_1, epoch_best_threshold_2, recall, precision, sens, spec)) 292 | 293 | training_log['val_metric_value_history'].append(epoch_max_metrics) 294 | training_log['val_loss_history'].append(epoch_loss) 295 | training_log['epoch_best_threshold_history'].append([epoch_best_threshold_1, epoch_best_threshold_2]) 296 | 297 | # deep copy the model 298 | if epoch_max_metrics > best_metric_value: 299 | best_metric_value = epoch_max_metrics 300 | best_threshold_1 = epoch_best_threshold_1 301 | best_threshold_2 = epoch_best_threshold_2 302 | best_model_wts = copy.deepcopy(model.state_dict()) 303 | best_optimizer_wts = copy.deepcopy(optimizer.state_dict()) 304 | best_log = copy.deepcopy(training_log) 305 | nodecrease = 0 306 | else: 307 | nodecrease += 1 308 | 309 | if nodecrease >= early_stop_epochs: 310 | early_stop = True 311 | 312 | if save_dir and epoch % save_epochs == 0 and epoch > 0: 313 | checkpoint = { 314 | 'model_state_dict': model.state_dict(), 315 | 'optimizer_state_dict': optimizer.state_dict(), 316 | 'training_log': training_log 317 | } 318 | torch.save(checkpoint, 319 | os.path.join(save_dir, model_name + '_' + str(training_log['current_epoch']) + '.tar')) 320 | 321 | if if_early_stop and early_stop: 322 | print('Early stopped!') 323 | break 324 | 325 | time_elapsed = time.time() - since 326 | print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 327 | print('Best validation metric value: {:4f}'.format(best_metric_value)) 328 | # print('Best validation threshold 1: {:4f}'.format(best_threshold_1)) 329 | 330 | # load best model weights 331 | if return_best: 332 | model.load_state_dict(best_model_wts) 333 | optimizer.load_state_dict(best_optimizer_wts) 334 | training_log = best_log 335 | 336 | checkpoint = { 337 | 'model_state_dict': model.state_dict(), 338 | 'optimizer_state_dict': optimizer.state_dict(), 339 | 'training_log': training_log 340 | } 341 | torch.save(checkpoint, 342 | os.path.join(save_dir, model_name + '_' + str(training_log['current_epoch']) + '_last.tar')) 343 | 344 | return model, training_log, best_metric_value, best_threshold_1, best_threshold_2 345 | 346 | 347 | def test_model(model, dataloader, metrics, threshold_list): 348 | stats1 = {x: {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0} for x in threshold_list} 349 | stats2 = {x: {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0} for x in threshold_list} 350 | 351 | metric_values = {} 352 | metric_values_1 = {} 353 | metric_values_2 = {} 354 | 355 | model.eval() 356 | for inputs, labels in tqdm(dataloader): 357 | inputs = inputs.to(device) 358 | labels = labels.to(device) 359 | 360 | with torch.set_grad_enabled(False): 361 | outputs = model(inputs) 362 | prob = torch.sigmoid(outputs) 363 | 364 | labels1 = labels[:, 0] 365 | labels2 = labels[:, 1] 366 | 367 | for threshold1 in threshold_list: 368 | preds1 = prob[:, 0] >= threshold1 369 | stats1[threshold1]['TP'] += torch.sum((preds1 == 1) * (labels1 == 1)).cpu().item() 370 | stats1[threshold1]['TN'] 
+= torch.sum((preds1 == 0) * (labels1 == 0)).cpu().item() 371 | stats1[threshold1]['FP'] += torch.sum((preds1 == 1) * (labels1 == 0)).cpu().item() 372 | stats1[threshold1]['FN'] += torch.sum((preds1 == 0) * (labels1 == 1)).cpu().item() 373 | 374 | for threshold2 in threshold_list: 375 | preds2 = prob[:, 1] >= threshold2 376 | stats2[threshold2]['TP'] += torch.sum((preds2 == 1) * (labels2 == 1)).cpu().item() 377 | stats2[threshold2]['TN'] += torch.sum((preds2 == 0) * (labels2 == 0)).cpu().item() 378 | stats2[threshold2]['FP'] += torch.sum((preds2 == 1) * (labels2 == 0)).cpu().item() 379 | stats2[threshold2]['FN'] += torch.sum((preds2 == 0) * (labels2 == 1)).cpu().item() 380 | 381 | for threshold1 in threshold_list: 382 | for threshold2 in threshold_list: 383 | metric_value_1, _ = metrics(stats1[threshold1]) 384 | _, metric_value_2 = metrics(stats2[threshold2]) 385 | metric_values[(threshold1, threshold2)] = 2.0 * metric_value_1 * (metric_value_2**2) / (metric_value_1 + (metric_value_2**2) + 1e-8) 386 | metric_values_1[threshold1] = metric_value_1 387 | metric_values_2[threshold2] = metric_value_2 388 | return metric_values, metric_values_1, metric_values_2 389 | 390 | 391 | data_transforms = { 392 | 'train': transforms.Compose([ 393 | transforms.Resize((input_size, input_size)), 394 | MyCrop(17, 0, 240, 299), 395 | transforms.Lambda(RandomRotationNew), 396 | transforms.RandomHorizontalFlip(p=0.5), 397 | transforms.RandomVerticalFlip(p=0.5), 398 | # transforms.Resize((input_size, input_size)), 399 | transforms.ToTensor(), 400 | # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 401 | ]), 402 | 'val': transforms.Compose([ 403 | transforms.Resize((input_size, input_size)), 404 | MyCrop(17, 0, 240, 299), 405 | # transforms.Resize((input_size, input_size)), 406 | transforms.ToTensor(), 407 | # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 408 | ]), 409 | 'test': transforms.Compose([ 410 | transforms.Resize((input_size, input_size)), 411 | MyCrop(17, 0, 240, 299), 412 | # transforms.Resize((input_size, input_size)), 413 | transforms.ToTensor(), 414 | # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 415 | ]) 416 | } 417 | 418 | if __name__ == '__main__': 419 | # data 420 | image_datasets = {x: FolderDirsDatasetMultiLabels(dirs_list_dict[x], transform=data_transforms[x]) for x in ['train', 'val', 'test']} 421 | dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, 422 | shuffle=True, num_workers=4) for x in ['train', 'val', 'test']} 423 | 424 | print('Training set size: ' + str(len(image_datasets['train']))) 425 | print('Validation set size: ' + str(len(image_datasets['val']))) 426 | print('Test set size: ' + str(len(image_datasets['test']))) 427 | 428 | results_dict = {x: {y: {z: {} for z in weight_decay_list} for y in lr_decay_epochs_list} for x in lr_list} 429 | 430 | if not os.path.exists(ckpt_save_dir): 431 | os.mkdir(ckpt_save_dir) 432 | 433 | # model 434 | for learning_rate in lr_list: 435 | for lr_decay_epochs in lr_decay_epochs_list: 436 | for weight_decay in weight_decay_list: 437 | print('----------------------- ' + 438 | str(learning_rate) + ', ' + 439 | str(lr_decay_epochs) + ', ' + 440 | str(weight_decay) + 441 | ' -----------------------') 442 | if model_arch == 'resnet18': 443 | model = resnet18(num_classes=nclasses) 444 | elif model_arch == 'resnet34': 445 | model = resnet34(num_classes=nclasses) 446 | elif model_arch == 'resnet50': 447 | model = resnet50(num_classes=nclasses) 448 | elif model_arch == 'inception': 449 | 
model = Inception3(num_classes=nclasses, aux_logits=True, transform_input=False) 450 | else: 451 | raise 452 | optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, 453 | weight_decay=weight_decay, amsgrad=True) 454 | pos_weight = torch.tensor([ib1, 1], dtype=torch.float).cuda() 455 | loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight) 456 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_epochs, gamma=lr_decay_rate) 457 | 458 | # load old parameters 459 | if old_ckpt_path: 460 | checkpoint = torch.load(old_ckpt_path) 461 | if old_ckpt_path[-4:] == '.tar': # it is a checkpoint dictionary rather than just model parameters 462 | model.load_state_dict(checkpoint['model_state_dict']) 463 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 464 | training_log = checkpoint['training_log'] 465 | 466 | else: 467 | del checkpoint['fc.weight'] 468 | del checkpoint['fc.bias'] 469 | if model_arch == 'inception': 470 | del checkpoint['AuxLogits.fc.weight'] 471 | del checkpoint['AuxLogits.fc.bias'] 472 | model.load_state_dict(checkpoint, strict=False) 473 | training_log = None # start from scratch 474 | 475 | print('Old checkpoint loaded: ' + old_ckpt_path) 476 | 477 | model = model.to(device) 478 | 479 | # fix some layers and make others trainable 480 | if trainable_params: 481 | only_train(model, trainable_params) 482 | 483 | best_model, _, best_metric_value, best_threshold_1, best_threshold_2 = train_model(model, model_name=model_name + '_lr_' + str( 484 | learning_rate) + '_decay_' + str(lr_decay_epochs) + '_wd_' + str(weight_decay), 485 | dataloaders=dataloaders_dict, criterion=loss_fn, 486 | optimizer=optimizer, metrics=metrics, 487 | num_epochs=num_epochs, 488 | training_log=training_log, verbose=True, 489 | return_best=return_best, 490 | if_early_stop=if_early_stop, 491 | early_stop_epochs=early_stop_epochs, 492 | scheduler=scheduler, save_dir=ckpt_save_dir, 493 | save_epochs=save_epochs) 494 | 495 | print('Begin test ...') 496 | test_metric_values, metric_values_1, metric_values_2 = test_model(best_model, dataloaders_dict['test'], metrics, threshold_list=threshold_list) 497 | 498 | best_threshold_test, best_metric_value_test = \ 499 | sorted(list(test_metric_values.items()), key=lambda tup: tup[1], reverse=True)[0] 500 | 501 | results_dict[learning_rate][lr_decay_epochs][weight_decay] = {'best_metrics_val': best_metric_value, 502 | 'best_threshold_val': (best_threshold_1, best_threshold_2), 503 | 'best_metric_test': best_metric_value_test, 504 | 'best_threshold_test': best_threshold_test, 505 | 'test_metrics_with_val_best_threshold': 506 | test_metric_values[(best_threshold_1, best_threshold_2)], 507 | 'test_metrics_1_with_val_best_threshold': 508 | metric_values_1[best_threshold_1], 509 | 'test_metrics_2_with_val_best_threshold': 510 | metric_values_2[best_threshold_2] 511 | } 512 | 513 | with open(join(ckpt_save_dir, 'results_dict.pickle'), 'wb') as f: 514 | pickle.dump(results_dict, f) 515 | 516 | print(results_dict) 517 | -------------------------------------------------------------------------------- /predict_HR.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils.data import Dataset, DataLoader 7 | import torchvision 8 | from torchvision import datasets, models, transforms, utils 9 | import 
torchvision.transforms.functional as TF 10 | 11 | from tqdm import tqdm 12 | import numpy as np 13 | import pandas as pd 14 | import pickle 15 | import matplotlib.pyplot as plt 16 | # import skimage 17 | # import skimage.io 18 | # import skimage.transform 19 | from PIL import Image 20 | import time 21 | import os 22 | from os.path import join, exists 23 | import copy 24 | import random 25 | from collections import OrderedDict 26 | from sklearn.metrics import r2_score 27 | 28 | from torch.nn import functional as F 29 | from torchvision.models import Inception3, resnet18, resnet34, resnet50 30 | 31 | from utils.image_dataset import * 32 | from LR_models.siamese_model_rgb import * 33 | 34 | """ 35 | This script is for generating the prediction scores of HR model for images in sequences. 36 | A sequence of images are stored in a folder. An image is named by the year of the image 37 | plus an auxiliary index. E.g., '2007_0.png', '2007_1.png', '2008_0.png'. 38 | """ 39 | 40 | dir_list = ['demo_sequences'] 41 | 42 | root_data_dir = 'data/sequences' 43 | old_ckpt_path = 'checkpoint/HR_decay_10_lr_0.0001_8_last.tar' 44 | result_path = 'results/HR_prob_dict.pickle' 45 | error_list_path = 'results/HR_error_list.pickle' 46 | 47 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 48 | input_size = 299 49 | batch_size = 64 50 | 51 | 52 | class MyCrop: 53 | def __init__(self, top, left, height, width): 54 | self.top = top 55 | self.left = left 56 | self.height = height 57 | self.width = width 58 | 59 | def __call__(self, img): 60 | return TF.crop(img, self.top, self.left, self.height, self.width) 61 | 62 | 63 | class SingleImageDatasetModified(Dataset): 64 | def __init__(self, dir_list, transform, latest_prob_dict): 65 | self.path_list = [] 66 | self.transform = transform 67 | 68 | for subdir in dir_list: 69 | data_dir = join(root_data_dir, subdir) 70 | for folder in os.listdir(data_dir): 71 | idx = folder.split('_')[0] 72 | folder_dir = join(data_dir, folder) 73 | for f in os.listdir(folder_dir): 74 | if not f[-4:] == '.png': 75 | continue 76 | if idx in latest_prob_dict and f in latest_prob_dict[idx]: 77 | continue 78 | self.path_list.append((subdir, folder, f)) 79 | 80 | def __len__(self): 81 | return len(self.path_list) 82 | 83 | def __getitem__(self, index): 84 | subdir, folder, fname = self.path_list[index] 85 | image_path = join(root_data_dir, subdir, folder, fname) 86 | idx = folder.split('_')[0] 87 | img = Image.open(image_path) 88 | if not img.mode == 'RGB': 89 | img = img.convert('RGB') 90 | img = self.transform(img) 91 | return img, idx, fname 92 | 93 | 94 | transform_test = transforms.Compose([ 95 | transforms.Resize(input_size), 96 | transforms.ToTensor(), 97 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 98 | ]) 99 | 100 | 101 | if __name__ == '__main__': 102 | # load existing prob dict or initialize a new one 103 | if exists(result_path): 104 | with open(result_path, 'rb') as f: 105 | prob_dict = pickle.load(f) 106 | else: 107 | prob_dict = {} 108 | 109 | # load existing error list or initialize a new one 110 | if exists(error_list_path): 111 | with open(error_list_path, 'rb') as f: 112 | error_list = pickle.load(f) 113 | else: 114 | error_list = [] 115 | 116 | # dataloader 117 | dataset_pred = SingleImageDatasetModified(dir_list, transform=transform_test, latest_prob_dict=prob_dict) 118 | print('Dataset size: ' + str(len(dataset_pred))) 119 | dataloader_pred = DataLoader(dataset_pred, batch_size=batch_size, shuffle=False, num_workers=4) 120 | 121 | # model 
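    # [Editor's note] The HR model is an Inception v3 with two output logits
    # (negative / positive); downstream code reads the stored score as the
    # probability that a solar installation is visible. In the loop below,
    # F.softmax(outputs, dim=1)[:, 1] is saved as prob_dict[idx][fname], i.e. the
    # pickle maps sequence index -> image filename -> positive-class probability.
    # Images already present in the loaded prob_dict are skipped by
    # SingleImageDatasetModified, so re-running the script resumes where it stopped.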
122 | model = Inception3(num_classes=2, aux_logits=True, transform_input=False) 123 | model = model.to(device) 124 | 125 | # load old parameters 126 | checkpoint = torch.load(old_ckpt_path, map_location=device) 127 | if old_ckpt_path[-4:] == '.tar': # it is a checkpoint dictionary rather than just model parameters 128 | model.load_state_dict(checkpoint['model_state_dict']) 129 | else: 130 | model.load_state_dict(checkpoint) 131 | print('Old checkpoint loaded: ' + old_ckpt_path) 132 | 133 | model.eval() 134 | # run 135 | count = 0 136 | for inputs, idx_list, fname_list in tqdm(dataloader_pred): 137 | try: 138 | inputs = inputs.to(device) 139 | with torch.set_grad_enabled(False): 140 | outputs = model(inputs) 141 | prob = F.softmax(outputs, dim=1) 142 | pos_prob_list = prob[:, 1].cpu().numpy() 143 | for i in range(len(idx_list)): 144 | idx = idx_list[i] 145 | fname = fname_list[i] 146 | pos_prob = pos_prob_list[i] 147 | 148 | if not idx in prob_dict: 149 | prob_dict[idx] = {} 150 | prob_dict[idx][fname] = pos_prob 151 | 152 | except: # take a note on the batch that causes error 153 | error_list.append((idx_list, fname_list)) 154 | if count % 200 == 0: 155 | with open(join(result_path), 'wb') as f: 156 | pickle.dump(prob_dict, f) 157 | with open(join(error_list_path), 'wb') as f: 158 | pickle.dump(error_list, f) 159 | count += 1 160 | 161 | with open(join(result_path), 'wb') as f: 162 | pickle.dump(prob_dict, f) 163 | with open(join(error_list_path), 'wb') as f: 164 | pickle.dump(error_list, f) 165 | 166 | print('Done!') 167 | 168 | 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /predict_LR_rgb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils.data import Dataset, DataLoader 7 | import torchvision 8 | from torchvision import datasets, models, transforms, utils 9 | import torchvision.transforms.functional as TF 10 | 11 | from tqdm import tqdm 12 | import numpy as np 13 | import pandas as pd 14 | import pickle 15 | import matplotlib.pyplot as plt 16 | # import skimage 17 | # import skimage.io 18 | # import skimage.transform 19 | from PIL import Image 20 | import time 21 | import os 22 | from os.path import join, exists 23 | import copy 24 | import random 25 | from collections import OrderedDict 26 | from sklearn.metrics import r2_score 27 | 28 | from torch.nn import functional as F 29 | from torchvision.models import Inception3 30 | 31 | from utils.image_dataset import * 32 | from LR_models.siamese_model_rgb import * 33 | 34 | """ 35 | This script is for generating the prediction scores of LR model for images in sequences. 36 | A sequence of images are stored in a folder. An image is named by the year of the image 37 | plus an auxiliary index. E.g., '2007_0.png', '2007_1.png', '2008_0.png'. 
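Editor's note on the output format (inferred from the code below; example values
are illustrative only): the script writes a nested dict to result_path, where
prob_dict[sequence_idx][anchor_filename][target_filename] holds the positive-class
probability produced by the LR model for that (anchor, target) pair, e.g.
prob_dict['6083']['2014_0.png']['2009_0.png'] = 0.87. Pairs already present in an
existing prob_dict are skipped, so the script can be re-run to resume after an
interruption.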
38 | """ 39 | 40 | dir_list = ['demo_sequences'] 41 | 42 | root_data_dir = 'data/sequences' 43 | old_ckpt_path = 'checkpoint/LR_nconvs_3_depth_128_nfilters_512_33_last.tar' 44 | result_path = 'results/LR_prob_dict.pickle' 45 | error_list_path = 'results/LR_error_list.pickle' 46 | anchor_images_dict_path = 'results/anchor_images_dict.pickle' 47 | 48 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 49 | backbone = 'resnet34' 50 | input_size = 299 51 | batch_size = 64 52 | 53 | transform_test = transforms.Compose([ 54 | transforms.Resize((input_size, input_size)), 55 | # transforms.Lambda(mask_image_info), 56 | transforms.ToTensor() 57 | ]) 58 | 59 | 60 | class ImagePairDatasetModified(Dataset): 61 | def __init__(self, dir_list, anchor_images_dict, transform, latest_prob_dict): 62 | self.couple_list = [] 63 | self.transform = transform 64 | 65 | for subdir in dir_list: 66 | data_dir = join(root_data_dir, subdir) 67 | for folder in os.listdir(data_dir): 68 | idx = folder.split('_')[0] 69 | folder_dir = join(data_dir, folder) 70 | if idx not in anchor_images_dict: 71 | continue 72 | anchor_images = anchor_images_dict[idx] 73 | for anchor_f in anchor_images: 74 | for tar_f in os.listdir(folder_dir): 75 | if not tar_f[-4:] == '.png': 76 | continue 77 | if idx in latest_prob_dict and anchor_f in latest_prob_dict[idx] and tar_f in latest_prob_dict[idx][anchor_f]: 78 | continue 79 | self.couple_list.append((subdir, folder, anchor_f, tar_f)) 80 | 81 | def __len__(self): 82 | return len(self.couple_list) 83 | 84 | def __getitem__(self, index): 85 | subdir, folder, anchor_f, tar_f = self.couple_list[index] 86 | ref_img_path = join(root_data_dir, subdir, folder, anchor_f) 87 | tar_img_path = join(root_data_dir, subdir, folder, tar_f) 88 | idx = folder.split('_')[0] 89 | 90 | img_ref = Image.open(ref_img_path) 91 | img_tar = Image.open(tar_img_path) 92 | if not img_ref.mode == 'RGB': 93 | img_ref = img_ref.convert('RGB') 94 | if not img_tar.mode == 'RGB': 95 | img_tar = img_tar.convert('RGB') 96 | 97 | img_ref = self.transform(img_ref) 98 | img_tar = self.transform(img_tar) 99 | 100 | return img_ref, img_tar, idx, anchor_f, tar_f 101 | 102 | 103 | if __name__ == '__main__': 104 | # load anchor_images_dict 105 | with open(anchor_images_dict_path, 'rb') as f: 106 | anchor_images_dict = pickle.load(f) 107 | 108 | # load existing prob dict or initialize a new one 109 | if exists(result_path): 110 | with open(result_path, 'rb') as f: 111 | prob_dict = pickle.load(f) 112 | else: 113 | prob_dict = {} 114 | 115 | # load existing error list or initialize a new one 116 | if exists(error_list_path): 117 | with open(error_list_path, 'rb') as f: 118 | error_list = pickle.load(f) 119 | else: 120 | error_list = [] 121 | 122 | # dataloader 123 | dataset_pred = ImagePairDatasetModified(dir_list, anchor_images_dict, transform=transform_test, latest_prob_dict=prob_dict) 124 | print('Dataset size: ' + str(len(dataset_pred))) 125 | dataloader_pred = DataLoader(dataset_pred, batch_size=batch_size, shuffle=False, num_workers=4) 126 | 127 | # model 128 | model = psn_depthwise_cc_layerwise_3layers_l234(backbone=backbone, nconvs=3, depth=128, nfilters=512, kernel_size=3) 129 | model = model.to(device) 130 | 131 | # load old parameters 132 | checkpoint = torch.load(old_ckpt_path, map_location=device) 133 | if old_ckpt_path[-4:] == '.tar': # it is a checkpoint dictionary rather than just model parameters 134 | model.load_state_dict(checkpoint['model_state_dict']) 135 | else: 136 | 
model.load_state_dict(checkpoint) 137 | print('Old checkpoint loaded: ' + old_ckpt_path) 138 | 139 | model.eval() 140 | 141 | # run 142 | count = 0 143 | for inputs_ref, inputs_tar, idx_list, anchor_f_list, tar_f_list in tqdm(dataloader_pred): 144 | try: 145 | inputs_ref = inputs_ref.to(device) 146 | inputs_tar = inputs_tar.to(device) 147 | with torch.set_grad_enabled(False): 148 | outputs = model(inputs_tar, inputs_ref) 149 | prob = F.softmax(outputs, dim=1) 150 | pos_prob_list = prob[:, 1].cpu().numpy() 151 | for i in range(len(idx_list)): 152 | idx = idx_list[i] 153 | anchor_f = anchor_f_list[i] 154 | tar_f = tar_f_list[i] 155 | pos_prob = pos_prob_list[i] 156 | 157 | if not idx in prob_dict: 158 | prob_dict[idx] = {} 159 | if not anchor_f in prob_dict[idx]: 160 | prob_dict[idx][anchor_f] = {} 161 | 162 | prob_dict[idx][anchor_f][tar_f] = pos_prob 163 | 164 | except: # take a note on the batch that causes error 165 | error_list.append((idx_list, anchor_f_list, tar_f_list)) 166 | 167 | if count % 400 == 0: 168 | with open(result_path, 'wb') as f: 169 | pickle.dump(prob_dict, f) 170 | with open(error_list_path, 'wb') as f: 171 | pickle.dump(error_list, f) 172 | count += 1 173 | 174 | with open(result_path, 'wb') as f: 175 | pickle.dump(prob_dict, f) 176 | with open(error_list_path, 'wb') as f: 177 | pickle.dump(error_list, f) 178 | 179 | print('Done!') 180 | -------------------------------------------------------------------------------- /predict_installation_year_from_image_sequences.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import os.path\n", 11 | "import pickle\n", 12 | "import shutil\n", 13 | "import pandas as pd\n", 14 | "import random\n", 15 | "import math\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "from os.path import join, exists\n", 18 | "import numpy as np\n", 19 | "from tqdm import tqdm\n", 20 | "import random\n", 21 | "import copy\n", 22 | "%matplotlib inline" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "This notebook takes the prediction scores of HR model, LR model, and blur detection model (ood model) as inputs, and outputs the installation year prediction for each image sequence." 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "dir_list = ['demo_sequences']" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "root_data_dir = 'data/sequences'" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "# 1. 
Load prob dicts" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "with open('results/HR_prob_dict.pickle', 'rb') as f:\n", 64 | " HR_prob_dict = pickle.load(f)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 5, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "with open('results/LR_prob_dict.pickle', 'rb') as f:\n", 74 | " LR_prob_dict = pickle.load(f)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 6, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "with open('results/ood_prob_dict.pickle', 'rb') as f:\n", 84 | " ood_prob_dict = pickle.load(f)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "# 2. Installation year detection" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 8, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "LR_threshold = 0.97\n", 101 | "blur_threshold = 0.29\n", 102 | "ood_threshold = 0.09" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 9, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "# Given LR prediction score, and the info of OOD (blur), return the year of installation.\n", 112 | "# Using OOD prediction (multiclass: [whether it is ood, whether it is HR]).\n", 113 | "# In this function, reference list (HR_prob > thres) is assumed to be the key values of LR_prob_dict.\n", 114 | "def hybrid_model_5(LR_prob_dict, blur_info, LR_threshold=0.5, ood_threshold=0.5, blur_threshold=0.5):\n", 115 | " \"\"\"\n", 116 | " LR_prob_dict: key1: anchor_filename, key2: target_filename, value: prob produced by LR model\n", 117 | " blur_info: key: filename, value: an array of two scores (OOD score and blur score)\n", 118 | " LR_threshold: to determine whether a LR image is positive or not.\n", 119 | " ood_threshold: to determint whether a image is out-of-distribution (\"impossible to detect\") or not.\n", 120 | " blur_threshold: to determine whether a image is HR or LR.\n", 121 | " \"\"\" \n", 122 | " def is_anchor_candidate(f):\n", 123 | " \"\"\" Determine whether an image can be a candidate of the \"positive anchor\" based on its blur \n", 124 | " score and OOD score. 
\"\"\"\n", 125 | " if blur_info[f][1] >= blur_threshold and blur_info[f][0] >= ood_threshold:\n", 126 | " return True # HR\n", 127 | " else:\n", 128 | " return False # LR or OOD\n", 129 | " \n", 130 | " # reference list: a list of image filenames with its HR prediction score >= HR_threshold\n", 131 | " reference_list = sorted(LR_prob_dict.keys()) # sorted in the time order\n", 132 | " \n", 133 | " # determine the \"positive anchor\"\n", 134 | " selected_anchors = [f for f in reference_list if is_anchor_candidate(f)]\n", 135 | " if selected_anchors:\n", 136 | " positive_anchor = selected_anchors[0] # use the earliest anchor image as the \"positive anchor\"\n", 137 | " else:\n", 138 | " positive_anchor = reference_list[-1]\n", 139 | " \n", 140 | " # determine the first target (LR) that surpass the threshold based on all referenced anchors\n", 141 | " for target in sorted(LR_prob_dict[positive_anchor].keys()): # go through all images\n", 142 | " if is_anchor_candidate(target): # skip those images with is HR\n", 143 | " continue\n", 144 | " if int(target.split('_')[0]) > int(positive_anchor.split('_')[0]): # skip those images later than positive anchor\n", 145 | " continue\n", 146 | " if blur_info[target][0] < ood_threshold: # don't consider OOD images but record them\n", 147 | " continue\n", 148 | " for ref in reference_list:\n", 149 | " if LR_prob_dict[ref][target] > LR_threshold:\n", 150 | " return max(min(2017, int(target.split('_')[0])), 2005), positive_anchor\n", 151 | " \n", 152 | " return max(min(2017, int(positive_anchor.split('_')[0])), 2005), positive_anchor" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 10, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "# To gather all \"critial\" years that are missing but may change the year prediction.\n", 162 | "# \"critical\" means that it is RIGHT before the predicted installation year\n", 163 | "def backtrack_missing_critical_years(LR_prob_dict, \n", 164 | " blur_info, \n", 165 | " positive_anchor, \n", 166 | " installation_year,\n", 167 | " LR_threshold,\n", 168 | " ood_threshold,\n", 169 | " blur_threshold):\n", 170 | " \"\"\"\n", 171 | " LR_prob_dict: key1: anchor_filename, key2: target_filename, value: prob produced by LR model\n", 172 | " blur_info: key: filename, value: an array of two scores (OOD score and blur score)\n", 173 | " ood_images: a list of image filenames which are identified as OOD and thus can be regarded as missing\n", 174 | " positive_anchor: the anchor image filename which is the earliest HR positive sample\n", 175 | " installation_year: the predicted year of installation\n", 176 | " LR_threshold: to determine whether a LR image is positive or not\n", 177 | " ood_threshold: to determint whether a image is out-of-distribution (\"impossible to detect\") or not.\n", 178 | " blur_threshold: to determine whether a image is HR or LR.\n", 179 | " \"\"\"\n", 180 | " all_images = sorted(LR_prob_dict[positive_anchor].keys()) # all image filenames in that sequence in the time order\n", 181 | " \n", 182 | " # reference list: a list of image filenames with its HR prediction score >= HR_threshold\n", 183 | " reference_list = set(sorted(LR_prob_dict.keys())) # sorted in the time order\n", 184 | " \n", 185 | " all_downloaded_years = {} # Note: only consider those years no later than installation_year\n", 186 | " for f in all_images:\n", 187 | " year = int(f.split('_')[0])\n", 188 | " if blur_info[f][0] >= ood_threshold or f in reference_list: # OOD images are regarded as missing\n", 
189 | " if year not in all_downloaded_years:\n", 190 | " all_downloaded_years[year] = []\n", 191 | " all_downloaded_years[year].append(f)\n", 192 | " \n", 193 | " missing_critial_years = []\n", 194 | " # backtracking\n", 195 | " curr_year = installation_year - 1\n", 196 | " while curr_year >= 2005 and curr_year not in all_downloaded_years:\n", 197 | " missing_critial_years.append(curr_year)\n", 198 | " curr_year -= 1\n", 199 | " \n", 200 | " if not missing_critial_years: # no missing\n", 201 | " return missing_critial_years\n", 202 | " \n", 203 | " if installation_year not in all_downloaded_years: # it indicates that the actual predicted year is 2018 but restricted to 2017\n", 204 | " assert installation_year == 2017\n", 205 | " return missing_critial_years + [2017]\n", 206 | " \n", 207 | "# if len(all_downloaded_years[installation_year]) == 1: # only one image in that year\n", 208 | "# return missing_critial_years\n", 209 | " \n", 210 | " for f in all_downloaded_years[installation_year]: \n", 211 | " # if any one of the images in the installtion year is negative (HR negative and LR negative), \n", 212 | " # then we can infer one sample is positive and another is negative in that year, \n", 213 | " # thus the solar panel must be installed in that year\n", 214 | " # then there is no missing critical year\n", 215 | " if blur_info[f][1] >= blur_threshold and f not in reference_list:\n", 216 | " return []\n", 217 | " if blur_info[f][1] < blur_threshold and f not in reference_list and all([LR_prob_dict[x][f] < LR_threshold for x in reference_list]):\n", 218 | " return []\n", 219 | " \n", 220 | " return missing_critial_years # a list of missing critial years" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "installation_year_dict = {} # sequence idx -> predicted installation year\n", 230 | "missing_years_dict = {} # sequence idx -> a list of missing critial years\n", 231 | "for idx in tqdm(HR_prob_dict):\n", 232 | " LR_prob_dict_sub = LR_prob_dict[idx]\n", 233 | " blur_info = ood_prob_dict[idx]\n", 234 | " installation_year, positive_anchor = hybrid_model_5(LR_prob_dict_sub, blur_info, LR_threshold, \n", 235 | " ood_threshold, blur_threshold)\n", 236 | " missing_years = backtrack_missing_critical_years(LR_prob_dict_sub, blur_info, positive_anchor, \n", 237 | " installation_year,\n", 238 | " LR_threshold, ood_threshold, blur_threshold)\n", 239 | " installation_year_dict[int(idx)] = installation_year\n", 240 | " if missing_years:\n", 241 | "# if not installation_year in missing_years:\n", 242 | "# missing_years_dict[int(idx)] = missing_years + [installation_year]\n", 243 | "# else:\n", 244 | " missing_years_dict[int(idx)] = missing_years\n", 245 | "print(len(installation_year_dict))\n", 246 | "print(len(missing_years_dict))" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 53, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "with open('results/installation_year_prediction_dict.pickle', 'wb') as f:\n", 256 | " pickle.dump(installation_year_dict, f) \n", 257 | "with open('results/missing_years_dict.pickle', 'wb') as f:\n", 258 | " pickle.dump(missing_years_dict, f)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [] 267 | } 268 | ], 269 | "metadata": { 270 | "kernelspec": { 271 | "display_name": "Environment (conda_tensorflow_p36)", 272 | "language": "python", 
273 | "name": "conda_tensorflow_p36" 274 | }, 275 | "language_info": { 276 | "codemirror_mode": { 277 | "name": "ipython", 278 | "version": 3 279 | }, 280 | "file_extension": ".py", 281 | "mimetype": "text/x-python", 282 | "name": "python", 283 | "nbconvert_exporter": "python", 284 | "pygments_lexer": "ipython3", 285 | "version": "3.6.5" 286 | } 287 | }, 288 | "nbformat": 4, 289 | "nbformat_minor": 4 290 | } 291 | -------------------------------------------------------------------------------- /predict_ood_multilabels.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils.data import Dataset, DataLoader 7 | import torchvision 8 | from torchvision import datasets, models, transforms, utils 9 | import torchvision.transforms.functional as TF 10 | 11 | from tqdm import tqdm 12 | import numpy as np 13 | import pandas as pd 14 | import pickle 15 | import matplotlib.pyplot as plt 16 | # import skimage 17 | # import skimage.io 18 | # import skimage.transform 19 | from PIL import Image 20 | import time 21 | import os 22 | from os.path import join, exists 23 | import copy 24 | import random 25 | from collections import OrderedDict 26 | from sklearn.metrics import r2_score 27 | 28 | from torch.nn import functional as F 29 | from torchvision.models import Inception3, resnet18, resnet34, resnet50 30 | 31 | from utils.image_dataset import * 32 | from LR_models.siamese_model_rgb import * 33 | 34 | """ 35 | This script is for generating the prediction scores of blur detection model for images in 36 | sequences. A sequence of images are stored in a folder. An image is named by the year of 37 | the image plus an auxiliary index. E.g., '2007_0.png', '2007_1.png', '2008_0.png'. 
38 | """ 39 | 40 | dir_list = ['demo_sequences'] 41 | 42 | root_data_dir = 'data/sequences' 43 | old_ckpt_path = 'checkpoint/ood_ib1_0.2_decay_10_wd_0_22_last.tar' 44 | result_path = 'results/ood_prob_dict.pickle' 45 | error_list_path = 'results/ood_error_list.pickle' 46 | model_arch = 'resnet50' 47 | 48 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 49 | input_size = 299 50 | batch_size = 64 51 | 52 | 53 | class MyCrop: 54 | def __init__(self, top, left, height, width): 55 | self.top = top 56 | self.left = left 57 | self.height = height 58 | self.width = width 59 | 60 | def __call__(self, img): 61 | return TF.crop(img, self.top, self.left, self.height, self.width) 62 | 63 | 64 | class SingleImageDatasetModified(Dataset): 65 | def __init__(self, dir_list, transform, latest_prob_dict): 66 | self.path_list = [] 67 | self.transform = transform 68 | 69 | for subdir in dir_list: 70 | data_dir = join(root_data_dir, subdir) 71 | for folder in os.listdir(data_dir): 72 | idx = folder.split('_')[0] 73 | folder_dir = join(data_dir, folder) 74 | for f in os.listdir(folder_dir): 75 | if not f[-4:] == '.png': 76 | continue 77 | if idx in latest_prob_dict and f in latest_prob_dict[idx]: 78 | continue 79 | self.path_list.append((subdir, folder, f)) 80 | 81 | def __len__(self): 82 | return len(self.path_list) 83 | 84 | def __getitem__(self, index): 85 | subdir, folder, fname = self.path_list[index] 86 | image_path = join(root_data_dir, subdir, folder, fname) 87 | idx = folder.split('_')[0] 88 | img = Image.open(image_path) 89 | if not img.mode == 'RGB': 90 | img = img.convert('RGB') 91 | img = self.transform(img) 92 | return img, idx, fname 93 | 94 | 95 | transform_test = transforms.Compose([ 96 | transforms.Resize((input_size, input_size)), 97 | MyCrop(17, 0, 240, 299), 98 | # transforms.Resize((input_size, input_size)), 99 | transforms.ToTensor(), 100 | # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 101 | ]) 102 | 103 | 104 | if __name__ == '__main__': 105 | # load existing prob dict or initialize a new one 106 | if exists(result_path): 107 | with open(result_path, 'rb') as f: 108 | prob_dict = pickle.load(f) 109 | else: 110 | prob_dict = {} 111 | 112 | # load existing error list or initialize a new one 113 | if exists(error_list_path): 114 | with open(error_list_path, 'rb') as f: 115 | error_list = pickle.load(f) 116 | else: 117 | error_list = [] 118 | 119 | # dataloader 120 | dataset_pred = SingleImageDatasetModified(dir_list, transform=transform_test, latest_prob_dict=prob_dict) 121 | print('Dataset size: ' + str(len(dataset_pred))) 122 | dataloader_pred = DataLoader(dataset_pred, batch_size=batch_size, shuffle=False, num_workers=4) 123 | 124 | # model 125 | if model_arch == 'resnet18': 126 | model = resnet18(num_classes=2) 127 | elif model_arch == 'resnet34': 128 | model = resnet34(num_classes=2) 129 | elif model_arch == 'resnet50': 130 | model = resnet50(num_classes=2) 131 | elif model_arch == 'inception': 132 | model = Inception3(num_classes=2, aux_logits=True, transform_input=False) 133 | else: 134 | raise 135 | model = model.to(device) 136 | 137 | # load old parameters 138 | checkpoint = torch.load(old_ckpt_path, map_location=device) 139 | if old_ckpt_path[-4:] == '.tar': # it is a checkpoint dictionary rather than just model parameters 140 | model.load_state_dict(checkpoint['model_state_dict']) 141 | else: 142 | model.load_state_dict(checkpoint) 143 | print('Old checkpoint loaded: ' + old_ckpt_path) 144 | 145 | model.eval() 146 | # run 147 | count = 0 148 | 
for inputs, idx_list, fname_list in tqdm(dataloader_pred): 149 | try: 150 | inputs = inputs.to(device) 151 | with torch.set_grad_enabled(False): 152 | outputs = model(inputs) 153 | prob = torch.sigmoid(outputs) 154 | prob_list = prob.cpu().numpy() 155 | for i in range(len(idx_list)): 156 | idx = idx_list[i] 157 | fname = fname_list[i] 158 | prob_sample = prob_list[i] 159 | 160 | if not idx in prob_dict: 161 | prob_dict[idx] = {} 162 | prob_dict[idx][fname] = prob_sample 163 | 164 | except: # take a note on the batch that causes error 165 | error_list.append((idx_list, fname_list)) 166 | if count % 200 == 0: 167 | with open(join(result_path), 'wb') as f: 168 | pickle.dump(prob_dict, f) 169 | with open(join(error_list_path), 'wb') as f: 170 | pickle.dump(error_list, f) 171 | count += 1 172 | 173 | with open(join(result_path), 'wb') as f: 174 | pickle.dump(prob_dict, f) 175 | with open(join(error_list_path), 'wb') as f: 176 | pickle.dump(error_list, f) 177 | 178 | print('Done!') 179 | 180 | 181 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | matplotlib 3 | rpy2 4 | numpy == 1.19.5 5 | scipy == 1.1.0 6 | Pillow == 5.2.0 7 | pandas == 0.24.2 8 | shapely == 1.7.1 9 | geopandas == 0.8.2 10 | geojson == 2.5.0 11 | scikit-learn == 0.22 12 | statsmodels == 0.9.0 13 | plotly == 4.14.3 14 | torch == 1.1.0 15 | torchvision == 0.2.2 -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzhecheng/DeepSolar_timelapse/b9f0e0e81b901426a166f55a3ba6ab53ea260407/utils/.DS_Store -------------------------------------------------------------------------------- /utils/.ipynb_checkpoints/collect_image_info-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzhecheng/DeepSolar_timelapse/b9f0e0e81b901426a166f55a3ba6ab53ea260407/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzhecheng/DeepSolar_timelapse/b9f0e0e81b901426a166f55a3ba6ab53ea260407/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/image_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzhecheng/DeepSolar_timelapse/b9f0e0e81b901426a166f55a3ba6ab53ea260407/utils/__pycache__/image_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /utils/image_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils.data import Dataset, DataLoader 7 | import 
torchvision 8 | from torchvision import datasets, models, transforms, utils 9 | import torchvision.transforms.functional as TF 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import pickle 14 | import matplotlib.pyplot as plt 15 | # import skimage 16 | # import skimage.io 17 | # import skimage.transform 18 | from PIL import Image 19 | import time 20 | import os 21 | from os.path import join, exists 22 | import copy 23 | import random 24 | from collections import OrderedDict 25 | 26 | 27 | class ImageFolderModified(Dataset): 28 | def __init__(self, root_dir, transform): 29 | self.root_dir = root_dir 30 | self.transform = transform 31 | self.idx2dir = [] 32 | self.path_list = [] 33 | for subdir in sorted(os.listdir(self.root_dir)): 34 | if not os.path.isfile(subdir): 35 | self.idx2dir.append(subdir) 36 | for class_idx, subdir in enumerate(self.idx2dir): 37 | class_dir = os.path.join(self.root_dir, subdir) 38 | for f in os.listdir(class_dir): 39 | if f[-4:] in ['.png', '.jpg', 'JPEG', 'jpeg']: 40 | self.path_list.append([os.path.join(class_dir, f), class_idx]) 41 | 42 | def __len__(self): 43 | return len(self.path_list) 44 | 45 | def __getitem__(self, idx): 46 | img_path, class_idx = self.path_list[idx] 47 | image = Image.open(img_path) 48 | if not image.mode == 'RGB': 49 | image = image.convert('RGB') 50 | image = self.transform(image) 51 | sample = [image, class_idx, img_path] 52 | return sample 53 | 54 | 55 | class BinaryImageFolder(Dataset): 56 | def __init__(self, root_dirs, transform): 57 | """ 58 | :param root_dirs: the list of root directories, the subdirectory of each root directory must be '0' and '1' 59 | :param transform: pytorch transform functions 60 | """ 61 | self.root_dirs = root_dirs 62 | self.transform = transform 63 | self.path_list = [] 64 | for root_dir in self.root_dirs: 65 | assert exists(join(root_dir, '0')) and exists(join(root_dir, '1')) 66 | for class_idx in [0, 1]: 67 | class_dir = join(root_dir, str(class_idx)) 68 | for f in os.listdir(class_dir): 69 | if f[-4:] in ['.png', '.jpg', 'JPEG', 'jpeg']: 70 | self.path_list.append([join(class_dir, f), class_idx]) 71 | 72 | def __len__(self): 73 | return len(self.path_list) 74 | 75 | def __getitem__(self, idx): 76 | img_path, class_idx = self.path_list[idx] 77 | image = Image.open(img_path) 78 | if not image.mode == 'RGB': 79 | image = image.convert('RGB') 80 | image = self.transform(image) 81 | sample = [image, class_idx, img_path] 82 | return sample 83 | 84 | 85 | class ImagePairDataset(Dataset): 86 | """ 87 | :param root_dirs: the list of root directories, the subdirectory of each root directory must be '0' and '1' 88 | :param reference_mapping_paths: the list of path to reference_mapping (dict) 89 | :param is_train: boolean indicating whether it is a training set 90 | :param binary: boolean indicating whether the images are binary 91 | :param transform: pytorch transform functions 92 | """ 93 | def __init__(self, root_dirs, reference_mapping_paths, is_train, binary, transform): 94 | self.couple_list = [] 95 | self.is_train = is_train 96 | self.binary = binary 97 | self.transform = transform 98 | 99 | assert len(root_dirs) == len(reference_mapping_paths) 100 | for i, root_dir in enumerate(root_dirs): 101 | reference_mapping_path = reference_mapping_paths[i] 102 | with open(reference_mapping_path, 'rb') as f: 103 | reference_mapping = pickle.load(f) 104 | 105 | x = root_dir.split('/')[-1] 106 | assert x in ['train', 'val', 'test'] 107 | 108 | for class_idx in [0, 1]: 109 | class_dir = join(root_dir, 
str(class_idx)) 110 | for f in os.listdir(class_dir): 111 | if f[-4:] in ['.png', '.jpg', 'JPEG', 'jpeg']: 112 | target_path = join(class_dir, f) 113 | target_subpath = join(x, str(class_idx), f) 114 | if target_subpath in reference_mapping: 115 | for ref_subpath in reference_mapping[target_subpath]: 116 | ref_path = root_dir.replace('/' + x, '/' + ref_subpath) 117 | if exists(ref_path): 118 | self.couple_list.append((ref_path, target_path, class_idx)) 119 | 120 | def __len__(self): 121 | return len(self.couple_list) 122 | 123 | def __getitem__(self, index): 124 | each_couple_list = self.couple_list[index] 125 | img_ref = Image.open(each_couple_list[0]) 126 | img_tar = Image.open(each_couple_list[1]) 127 | if not self.binary: 128 | if not img_ref.mode == 'RGB': 129 | img_ref = img_ref.convert('RGB') 130 | if not img_tar.mode == 'RGB': 131 | img_tar = img_tar.convert('RGB') 132 | 133 | if self.is_train: 134 | angle = random.choice([0, 90, 180, 270]) 135 | img_ref = TF.rotate(img_ref, angle) 136 | img_tar = TF.rotate(img_tar, angle) 137 | 138 | img_ref = self.transform(img_ref) 139 | img_tar = self.transform(img_tar) 140 | 141 | label = each_couple_list[2] 142 | return img_ref, img_tar, label 143 | # return img_tar, label 144 | 145 | 146 | class SequenceDataset(Dataset): 147 | def __init__(self, folder_dir, transform): 148 | """ 149 | :param folder_dir: e.g. "/home/ubuntu/projects/data/deepsolar2/cleaned/sequence_0/6083_128565_2012" 150 | """ 151 | self.path_list = [] 152 | self.transform = transform 153 | for f in sorted(os.listdir(folder_dir)): 154 | if f[-4:] in ['.png', '.jpg', 'JPEG', 'jpeg']: 155 | self.path_list.append(join(folder_dir, f)) 156 | 157 | def __len__(self): 158 | return len(self.path_list) 159 | 160 | def __getitem__(self, idx): 161 | img_path = self.path_list[idx] 162 | image = Image.open(img_path) 163 | if not image.mode == 'RGB': 164 | image = image.convert('RGB') 165 | image = self.transform(image) 166 | sample = [image, img_path.split('/')[-1]] 167 | return sample 168 | 169 | 170 | class FolderDirsDataset(Dataset): 171 | def __init__(self, dirs_list, transform): 172 | """ 173 | :param dirs_list: list. Length: number of classes. Each entries is a list of directories belonging to its class. 174 | """ 175 | self.sample_list = [] 176 | self.transform = transform 177 | for i, dirs in enumerate(dirs_list): 178 | for dir in dirs: 179 | for f in os.listdir(dir): 180 | if f[-4:] in ['.png', '.jpg', 'JPEG', 'jpeg']: 181 | self.sample_list.append((join(dir, f), i)) 182 | 183 | def __len__(self): 184 | return len(self.sample_list) 185 | 186 | def __getitem__(self, idx): 187 | img_path, class_idx = self.sample_list[idx] 188 | image = Image.open(img_path) 189 | if not image.mode == 'RGB': 190 | image = image.convert('RGB') 191 | image = self.transform(image) 192 | sample = [image, class_idx] 193 | return sample 194 | 195 | 196 | class FolderDirsDatasetMultiLabels(Dataset): 197 | def __init__(self, dirs_list, transform): 198 | """ 199 | :param dirs_list: list. Length: number of classes. Each entries is a list of directories belonging to its class. 
200 | """ 201 | self.sample_list = [] 202 | self.transform = transform 203 | for i, dirs in enumerate(dirs_list): 204 | for dir in dirs: 205 | for f in os.listdir(dir): 206 | if f[-4:] in ['.png', '.jpg', 'JPEG', 'jpeg']: 207 | if i == 0: 208 | self.sample_list.append((join(dir, f), 0, 0)) 209 | elif i == 1: 210 | self.sample_list.append((join(dir, f), 1, 0)) 211 | else: 212 | self.sample_list.append((join(dir, f), 1, 1)) 213 | 214 | def __len__(self): 215 | return len(self.sample_list) 216 | 217 | def __getitem__(self, idx): 218 | img_path, class_idx_1, class_idx_2 = self.sample_list[idx] 219 | image = Image.open(img_path) 220 | if not image.mode == 'RGB': 221 | image = image.convert('RGB') 222 | image = self.transform(image) 223 | sample = [image, torch.tensor([class_idx_1, class_idx_2], dtype=torch.float)] 224 | return sample 225 | 226 | 227 | class FolderDirsDatasetMultiLabelsForSolarTypes(Dataset): 228 | def __init__(self, dirs_list, transform): 229 | """ 230 | :param dirs_list: list. Length: number of classes. Each entries is a list of directories belonging to its class. 231 | """ 232 | self.sample_list = [] 233 | self.transform = transform 234 | for i, dirs in enumerate(dirs_list): 235 | for dir in dirs: 236 | for f in os.listdir(dir): 237 | if f[-4:] in ['.png', '.jpg', 'JPEG', 'jpeg']: 238 | if i == 0: 239 | self.sample_list.append((join(dir, f), [0, 0, 0, 0])) # negative 240 | elif i == 1: 241 | self.sample_list.append((join(dir, f), [1, 0, 0, 0])) # solar water heating 242 | elif i == 2: 243 | self.sample_list.append((join(dir, f), [1, 1, 0, 0])) # residential solar 244 | elif i == 3: 245 | self.sample_list.append((join(dir, f), [1, 1, 1, 0])) # commercial solar 246 | else: 247 | self.sample_list.append((join(dir, f), [1, 1, 1, 1])) # utility-scale solar 248 | 249 | def __len__(self): 250 | return len(self.sample_list) 251 | 252 | def __getitem__(self, idx): 253 | img_path, class_labels = self.sample_list[idx] 254 | # print(img_path) 255 | # class_idx_1, class_idx_2, class_idx_3, class_idx_4 = class_labels 256 | image = Image.open(img_path) 257 | if not image.mode == 'RGB': 258 | image = image.convert('RGB') 259 | image = self.transform(image) 260 | sample = [image, torch.tensor(class_labels, dtype=torch.float)] 261 | return sample 262 | -------------------------------------------------------------------------------- /utils/inception_modified.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils.data import Dataset, DataLoader 7 | import torchvision 8 | from torchvision import datasets, models, transforms, utils 9 | import torchvision.transforms.functional as TF 10 | 11 | from tqdm import tqdm 12 | import numpy as np 13 | import json 14 | import pandas as pd 15 | import pickle 16 | import matplotlib.pyplot as plt 17 | # import skimage 18 | # import skimage.io 19 | # import skimage.transform 20 | from PIL import Image 21 | import time 22 | import os 23 | from os.path import join, exists 24 | import copy 25 | import random 26 | from collections import OrderedDict 27 | from sklearn.metrics import r2_score 28 | 29 | 30 | import torch.nn.functional as F 31 | from torchvision.models import Inception3 32 | from collections import namedtuple 33 | 34 | _InceptionOuputs = namedtuple('InceptionOuputs', ['logits', 'aux_logits']) 35 | 36 | 37 | class InceptionSegmentation(nn.Module): 38 | 
def __init__(self, num_outputs=2, level=1): 39 | super(InceptionSegmentation, self).__init__() 40 | assert level in [1, 2] 41 | self.level = level 42 | self.inception3 = Inception3_modified(num_classes=num_outputs, aux_logits=False, transform_input=False) 43 | self.convolution1 = nn.Conv2d(288, 512, bias=True, kernel_size=3, padding=1) 44 | if self.level == 1: 45 | self.linear1 = nn.Linear(512, num_outputs, bias=False) 46 | else: 47 | self.convolution2 = nn.Conv2d(512, 512, bias=True, kernel_size=3, padding=1) 48 | self.linear2 = nn.Linear(512, num_outputs, bias=False) 49 | 50 | def forward(self, x, testing=False): 51 | logits, intermediate = self.inception3(x) 52 | feature_map = self.convolution1(intermediate) # N x 512 x 35 x 35 53 | feature_map = F.relu(feature_map) # N x 512 x 35 x 35 54 | if self.level == 1: 55 | y = F.adaptive_avg_pool2d(feature_map, (1, 1)) 56 | y = y.view(y.size(0), -1) # N x 512 57 | y = self.linear1(y) # N x 2 58 | if testing: 59 | CAM = self.linear1.weight.data[1, :] * feature_map.permute(0, 2, 3, 1) 60 | CAM = CAM.sum(dim=3) 61 | else: 62 | feature_map = self.convolution2(feature_map) # N x 512 x 35 x 35 63 | feature_map = F.relu(feature_map) # N x 512 x 35 x 35 64 | y = F.adaptive_avg_pool2d(feature_map, (1, 1)) 65 | y = y.view(y.size(0), -1) # N x 512 66 | y = self.linear2(y) # N x 2 67 | if testing: 68 | CAM = self.linear2.weight.data[1, :] * feature_map.permute(0, 2, 3, 1) 69 | CAM = CAM.sum(dim=3) 70 | if testing: 71 | return y, logits, CAM 72 | else: 73 | return y 74 | 75 | def load_basic_params(self, model_path, device=torch.device('cpu')): 76 | """Only load the parameters from main branch.""" 77 | old_params = torch.load(model_path, map_location=device) 78 | if model_path[-4:] == '.tar': # The file is not a model state dict, but a checkpoint dict 79 | old_params = old_params['model_state_dict'] 80 | self.inception3.load_state_dict(old_params, strict=False) 81 | print('Loaded basic model parameters from: ' + model_path) 82 | 83 | def load_existing_params(self, model_path, device=torch.device('cpu')): 84 | """Load the parameters of main branch and parameters of level-1 layers (and perhaps level-2 layers.)""" 85 | old_params = torch.load(model_path, map_location=device) 86 | if model_path[-4:] == '.tar': # The file is not a model state dict, but a checkpoint dict 87 | old_params = old_params['model_state_dict'] 88 | self.load_state_dict(old_params, strict=False) 89 | print('Loaded existing model parameters from: ' + model_path) 90 | 91 | 92 | class Inception3_modified(Inception3): 93 | def forward(self, x): 94 | if self.transform_input: 95 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 96 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 97 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 98 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 99 | # N x 3 x 299 x 299 100 | x = self.Conv2d_1a_3x3(x) 101 | # N x 32 x 149 x 149 102 | x = self.Conv2d_2a_3x3(x) 103 | # N x 32 x 147 x 147 104 | x = self.Conv2d_2b_3x3(x) 105 | # N x 64 x 147 x 147 106 | x = F.max_pool2d(x, kernel_size=3, stride=2) 107 | # N x 64 x 73 x 73 108 | x = self.Conv2d_3b_1x1(x) 109 | # N x 80 x 73 x 73 110 | x = self.Conv2d_4a_3x3(x) 111 | # N x 192 x 71 x 71 112 | x = F.max_pool2d(x, kernel_size=3, stride=2) 113 | # N x 192 x 35 x 35 114 | x = self.Mixed_5b(x) 115 | # N x 256 x 35 x 35 116 | x = self.Mixed_5c(x) 117 | # N x 288 x 35 x 35 118 | x = self.Mixed_5d(x) 119 | # N x 288 x 35 x 35 120 | intermediate = 
x.clone() 121 | x = self.Mixed_6a(x) 122 | # N x 768 x 17 x 17 123 | x = self.Mixed_6b(x) 124 | # N x 768 x 17 x 17 125 | x = self.Mixed_6c(x) 126 | # N x 768 x 17 x 17 127 | x = self.Mixed_6d(x) 128 | # N x 768 x 17 x 17 129 | x = self.Mixed_6e(x) 130 | # N x 768 x 17 x 17 131 | if self.training and self.aux_logits: 132 | aux = self.AuxLogits(x) 133 | # N x 768 x 17 x 17 134 | x = self.Mixed_7a(x) 135 | # N x 1280 x 8 x 8 136 | x = self.Mixed_7b(x) 137 | # N x 2048 x 8 x 8 138 | x = self.Mixed_7c(x) 139 | # N x 2048 x 8 x 8 140 | # Adaptive average pooling 141 | x = F.adaptive_avg_pool2d(x, (1, 1)) 142 | # N x 2048 x 1 x 1 143 | x = F.dropout(x, training=self.training) 144 | # N x 2048 x 1 x 1 145 | x = x.view(x.size(0), -1) 146 | # N x 2048 147 | x = self.fc(x) 148 | # N x 1000 (num_classes) 149 | if self.training and self.aux_logits: 150 | return _InceptionOuputs(x, aux) 151 | return x, intermediate --------------------------------------------------------------------------------
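
Notes on the code above (not part of the repository files):

The prediction loop at the top of this section streams batches from dataloader_pred, applies a sigmoid to the model outputs, and stores each per-image probability vector in prob_dict[idx][fname], pickling both the results and an error list every 200 batches so a long run can be resumed or audited. Below is a minimal sketch of that same pattern, assuming the model, dataloader_pred, device, result_path, and error_list_path objects defined earlier in that script; the narrowed `except Exception` clause and the `setdefault` call are small idiomatic substitutions, not the repository's exact wording.

    import pickle
    import torch
    from tqdm import tqdm

    def run_prediction(model, dataloader_pred, device, result_path, error_list_path,
                       checkpoint_every=200):
        """Sketch of the batched inference loop with periodic pickle checkpoints."""
        model.eval()
        prob_dict, error_list = {}, []
        for count, (inputs, idx_list, fname_list) in enumerate(tqdm(dataloader_pred)):
            try:
                inputs = inputs.to(device)
                with torch.no_grad():                     # same effect as set_grad_enabled(False)
                    prob = torch.sigmoid(model(inputs))   # per-label probabilities
                for idx, fname, p in zip(idx_list, fname_list, prob.cpu().numpy()):
                    prob_dict.setdefault(idx, {})[fname] = p
            except Exception:                             # note the failing batch and keep going
                error_list.append((idx_list, fname_list))
            if count % checkpoint_every == 0:             # periodic checkpoint to disk
                with open(result_path, 'wb') as f:
                    pickle.dump(prob_dict, f)
                with open(error_list_path, 'wb') as f:
                    pickle.dump(error_list, f)
        # final dump once the loader is exhausted
        with open(result_path, 'wb') as f:
            pickle.dump(prob_dict, f)
        with open(error_list_path, 'wb') as f:
            pickle.dump(error_list, f)
        return prob_dict, error_list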
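
FolderDirsDatasetMultiLabelsForSolarTypes in utils/image_dataset.py encodes its five folder classes as cumulative (ordinal) label vectors: negatives map to [0, 0, 0, 0], solar water heating to [1, 0, 0, 0], residential solar to [1, 1, 0, 0], commercial solar to [1, 1, 1, 0], and utility-scale solar to [1, 1, 1, 1], so each sigmoid output answers a nested "is it at least this category" question. The sketch below shows one way such cumulative predictions could be decoded back to a single class index; the decode_ordinal helper and the 0.5 threshold are illustrative assumptions and do not appear in the repository.

    import torch

    def decode_ordinal(probs, threshold=0.5):
        """Hypothetical decoder for cumulative multi-label sigmoid outputs.

        probs: (N, 4) tensor ordered as in FolderDirsDatasetMultiLabelsForSolarTypes.
        Returns an integer class in {0, ..., 4}: the number of leading thresholds exceeded.
        """
        exceeded = (probs >= threshold).long()
        # cumprod zeroes out everything after the first 0, so the sum counts leading 1s
        return exceeded.cumprod(dim=1).sum(dim=1)

    # e.g. sigmoid outputs [0.9, 0.8, 0.2, 0.1] -> class 2 (residential solar)
    example = torch.tensor([[0.9, 0.8, 0.2, 0.1]])
    print(decode_ordinal(example))  # tensor([2])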
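
In utils/inception_modified.py, InceptionSegmentation builds a class activation map (CAM) at test time by weighting the 512-channel, 35x35 feature map with the classifier weights of the positive class and summing over channels, which is what the expression `self.linear1.weight.data[1, :] * feature_map.permute(0, 2, 3, 1)` followed by `.sum(dim=3)` computes. The snippet below restates that step with torch.einsum purely for readability; it assumes an (N, 512, 35, 35) feature map and a (num_outputs, 512) linear weight matrix, as in the level-1 branch.

    import torch

    def class_activation_map(feature_map, linear_weight, class_idx=1):
        """Equivalent restatement of the CAM step in InceptionSegmentation.

        feature_map:   (N, C, H, W) activations, e.g. (N, 512, 35, 35)
        linear_weight: (num_outputs, C) weights of the final linear layer
        Returns an (N, H, W) map highlighting where the chosen class responds.
        """
        w = linear_weight[class_idx]                      # (C,)
        return torch.einsum('nchw,c->nhw', feature_map, w)

    # Original formulation, for comparison:
    #   CAM = w * feature_map.permute(0, 2, 3, 1)   # (N, H, W, C)
    #   CAM = CAM.sum(dim=3)                        # (N, H, W)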