├── .gitattributes
├── .github
│   └── workflows
│       └── greetings.yml
├── .gitignore
├── AutoEncoder.py
├── CutTarget.py
├── DataAugment.py
├── Datasets.py
├── GAN.py
├── GenerateTestDataset.py
├── ImagePreProcessing.py
├── ImageTools.py
├── LICENSE
├── MetaLearning.py
├── ModelLoader.py
├── README.md
├── TargetDiscriminator.py
├── VAE_GAN_train.py
├── main.py
├── requirements.txt
├── resource
│   └── effect.png
└── segment_anything
    ├── __init__.py
    ├── automatic_mask_generator.py
    ├── build_sam.py
    ├── modeling
    │   ├── __init__.py
    │   ├── common.py
    │   ├── image_encoder.py
    │   ├── mask_decoder.py
    │   ├── prompt_encoder.py
    │   ├── sam.py
    │   └── transformer.py
    ├── predictor.py
    └── utils
        ├── __init__.py
        ├── amg.py
        ├── onnx.py
        └── transforms.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | *.pth filter=lfs diff=lfs merge=lfs -text
2 | 
--------------------------------------------------------------------------------
/.github/workflows/greetings.yml:
--------------------------------------------------------------------------------
1 | name: Greetings
2 | 
3 | on: [pull_request_target, issues]
4 | 
5 | jobs:
6 |   greeting:
7 |     runs-on: ubuntu-latest
8 |     permissions:
9 |       issues: write
10 |       pull-requests: write
11 |     steps:
12 |       - uses: actions/first-interaction@v1
13 |         with:
14 |           repo-token: ${{ secrets.GITHUB_TOKEN }}
15 |           issue-message: "Message that will be displayed on users' first issue"
16 |           pr-message: "Message that will be displayed on users' first pull request"
17 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | 162 | # project 163 | saved_model/ 164 | input/test.bmp 165 | output/ -------------------------------------------------------------------------------- /AutoEncoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | import torch.optim as optim 8 | from torch.optim.lr_scheduler import CosineAnnealingLR 9 | 10 | from ImagePreProcessing import * 11 | from ImageTools import * 12 | from ModelLoader import ModelLoader 13 | 14 | latent_dim = 256 15 | input_channel = 1 16 | # max_size = 1536 17 | max_size = 1792 18 | 19 | 20 | class SELayer(nn.Module): 21 | def __init__(self, channel, reduction=16): 22 | super(SELayer, self).__init__() 23 | # Squeeze operation 24 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 25 | # Excitation operation 26 | self.fc = nn.Sequential( 27 | nn.Linear(channel, channel // reduction, bias=False), 28 | nn.ReLU(), 29 | nn.Linear(channel // reduction, channel, bias=False), 30 | nn.Sigmoid() 31 | ) 32 | 33 | def forward(self, x): 34 | b, c, _, _ = x.size() 35 | y = self.avg_pool(x).view(b, c) 36 | y_fc = self.fc(y).view(b, c, 1, 1) 37 | y_ex = y_fc.expand(x.size()) 38 | return x * y_ex 39 | 40 | 41 | class SpatialPyramidPooling(nn.Module): 42 | def __init__(self, pool_sizes): 43 | super(SpatialPyramidPooling, self).__init__() 44 | self.pool_sizes = pool_sizes 45 | 46 | def forward(self, x): 47 | features = [] 48 | for pool_size in self.pool_sizes: 49 | features.append(F.adaptive_avg_pool2d(x, pool_size).view(x.size(0), -1)) 50 | return torch.cat(features, 1) 51 | 52 | 53 | class VAEEncoder(nn.Module): 54 | def __init__(self, z_dim): 55 | super(VAEEncoder, self).__init__() 56 | 57 | # Initial batch normalization 58 | self.initial_norm = nn.InstanceNorm2d(input_channel) 59 | 60 | # Encoder architecture: Five sets of [BN -> Conv2d] for RGB images 61 | self.conv1 = nn.Sequential( 62 | nn.Conv2d(input_channel, max_size // 16, kernel_size=4, stride=2, padding=1), 63 | # Output: [max_size // 16, 256, 256] 64 | nn.InstanceNorm2d(max_size // 16), 65 | SELayer(max_size // 16), 66 | nn.LeakyReLU(0.2), 67 | ) 68 | self.conv2 = nn.Sequential( 69 | nn.Conv2d(max_size // 16, max_size // 8, kernel_size=4, stride=2, padding=1), 70 | # Output: [max_size // 8, 128, 128] 71 | nn.InstanceNorm2d(max_size // 8), 72 | SELayer(max_size // 8), 73 | nn.LeakyReLU(0.2), 74 | ) 75 | self.conv3 = nn.Sequential( 76 | nn.Conv2d(max_size // 8, max_size // 4, kernel_size=4, stride=2, padding=1), 77 | # Output: [max_size // 4, 64, 64] 78 | nn.InstanceNorm2d(max_size // 4), 79 | SELayer(max_size // 4), 80 | nn.LeakyReLU(0.2), 81 | ) 82 | self.conv4 = nn.Sequential( 83 | nn.Conv2d(max_size // 4, max_size // 2, kernel_size=4, stride=2, padding=1), 84 | # Output: [max_size // 2, 32, 32] 85 | nn.InstanceNorm2d(max_size // 2), 86 | SELayer(max_size // 2), 87 | nn.LeakyReLU(0.2), 88 | ) 89 | self.conv5 = nn.Sequential( 90 | nn.Conv2d(max_size // 2, max_size, kernel_size=4, stride=2, padding=1), 91 | # Output: [max_size, 16, 16] 92 | nn.InstanceNorm2d(max_size), 93 | SELayer(max_size), 94 | nn.LeakyReLU(0.2), 95 | ) 96 | 97 | # Residual connections 98 | self.skip1 = nn.Conv2d(input_channel, max_size // 16, kernel_size=1, stride=2, padding=0) 99 | self.skip2 = nn.Conv2d(max_size // 16, max_size // 8, kernel_size=1, stride=2, padding=0) 100 | self.skip3 = nn.Conv2d(max_size // 8, max_size // 4, kernel_size=1, stride=2, padding=0) 101 | 
self.skip4 = nn.Conv2d(max_size // 4, max_size // 2, kernel_size=1, stride=2, padding=0) 102 | self.skip5 = nn.Conv2d(max_size // 2, max_size, kernel_size=1, stride=2, padding=0) 103 | 104 | # Add SPP layer 105 | self.spp = SpatialPyramidPooling([1, 2, 4]) 106 | # Calculate the flattened size after the SPP layer 107 | spp_total_size = max_size * (1 * 1 + 2 * 2 + 4 * 4) 108 | self.fc_mu = nn.Linear(spp_total_size, z_dim) 109 | self.fc_var = nn.Linear(spp_total_size, z_dim) 110 | 111 | # 初始化 112 | init.kaiming_uniform_(self.conv1[0].weight, a=0.2, nonlinearity='leaky_relu') 113 | init.kaiming_uniform_(self.conv2[0].weight, a=0.2, nonlinearity='leaky_relu') 114 | init.kaiming_uniform_(self.conv3[0].weight, a=0.2, nonlinearity='leaky_relu') 115 | init.kaiming_uniform_(self.conv4[0].weight, a=0.2, nonlinearity='leaky_relu') 116 | init.kaiming_uniform_(self.conv5[0].weight, a=0.2, nonlinearity='leaky_relu') 117 | init.xavier_uniform_(self.fc_mu.weight) 118 | self.fc_mu.bias.data.fill_(0) 119 | init.xavier_uniform_(self.fc_var.weight) 120 | self.fc_var.bias.data.fill_(0) 121 | 122 | def forward(self, x): 123 | # Apply initial batch normalization 124 | x_ini = self.initial_norm(x) 125 | 126 | # Apply the five sets of [Conv2d -> BN -> ReLU] with residual connections 127 | identity1 = self.skip1(x_ini) 128 | x1 = self.conv1(x_ini) + identity1 129 | 130 | identity2 = self.skip2(x1) 131 | x2 = self.conv2(x1) + identity2 132 | 133 | identity3 = self.skip3(x2) 134 | x3 = self.conv3(x2) + identity3 135 | 136 | identity4 = self.skip4(x3) 137 | x4 = self.conv4(x3) + identity4 138 | 139 | identity5 = self.skip5(x4) 140 | x5 = self.conv5(x4) + identity5 141 | 142 | # Pass through the SPP layer 143 | x_spp = self.spp(x5) 144 | 145 | # Flatten the output for the fully connected layers 146 | x_final = x_spp.view(x_spp.size(0), -1) 147 | 148 | # Pass through the fully connected layers 149 | z_mu = self.fc_mu(x_final) 150 | z_var = self.fc_var(x_final) 151 | 152 | return z_mu, z_var 153 | 154 | 155 | class VAEDecoder(nn.Module): 156 | def __init__(self, z_dim): 157 | super(VAEDecoder, self).__init__() 158 | 159 | self.feature_map_size = max_size * 16 * 16 # [max_size, 16, 16] 160 | 161 | self.fc = nn.Linear(z_dim, self.feature_map_size) 162 | 163 | self.conv_transpose1 = nn.Sequential( 164 | nn.ConvTranspose2d(max_size, max_size // 2, kernel_size=4, stride=2, padding=1), 165 | nn.InstanceNorm2d(max_size // 2), 166 | SELayer(max_size // 2), 167 | nn.LeakyReLU(0.2), 168 | ) 169 | self.conv_transpose2 = nn.Sequential( 170 | nn.ConvTranspose2d(max_size // 2, max_size // 4, kernel_size=4, stride=2, padding=1), 171 | nn.InstanceNorm2d(max_size // 4), 172 | SELayer(max_size // 4), 173 | nn.LeakyReLU(0.2), 174 | ) 175 | self.conv_transpose3 = nn.Sequential( 176 | nn.ConvTranspose2d(max_size // 4, max_size // 8, kernel_size=4, stride=2, padding=1), 177 | nn.InstanceNorm2d(max_size // 8), 178 | SELayer(max_size // 8), 179 | nn.LeakyReLU(0.2), 180 | ) 181 | self.conv_transpose4 = nn.Sequential( 182 | nn.ConvTranspose2d(max_size // 8, max_size // 16, kernel_size=4, stride=2, padding=1), 183 | nn.InstanceNorm2d(max_size // 16), 184 | SELayer(max_size // 16), 185 | nn.LeakyReLU(0.2), 186 | ) 187 | self.conv_transpose5 = nn.Sequential( 188 | nn.ConvTranspose2d(max_size // 16, input_channel, kernel_size=4, stride=2, padding=1), 189 | nn.Tanh() # 使用Tanh激活函数将像素值限制在-1到1之间 190 | ) 191 | 192 | # 残差连接 193 | self.skip1 = nn.ConvTranspose2d(max_size, max_size // 2, kernel_size=4, stride=2, padding=1) 194 | self.skip2 = 
nn.ConvTranspose2d(max_size // 2, max_size // 4, kernel_size=4, stride=2, padding=1) 195 | self.skip3 = nn.ConvTranspose2d(max_size // 4, max_size // 8, kernel_size=4, stride=2, padding=1) 196 | self.skip4 = nn.ConvTranspose2d(max_size // 8, max_size // 16, kernel_size=4, stride=2, padding=1) 197 | self.skip5 = nn.ConvTranspose2d(max_size // 16, input_channel, kernel_size=4, stride=2, padding=1) 198 | 199 | self._initialize_weights() 200 | 201 | def forward(self, x): 202 | x_fc = self.fc(x) 203 | x_viewed = x_fc.view(-1, max_size, 16, 16) 204 | 205 | identity1 = self.skip1(x_viewed) 206 | x1 = self.conv_transpose1(x_viewed) + identity1 207 | 208 | identity2 = self.skip2(x1) 209 | x2 = self.conv_transpose2(x1) + identity2 210 | 211 | identity3 = self.skip3(x2) 212 | x3 = self.conv_transpose3(x2) + identity3 213 | 214 | identity4 = self.skip4(x3) 215 | x4 = self.conv_transpose4(x3) + identity4 216 | 217 | identity5 = self.skip5(x4) 218 | x5 = self.conv_transpose5(x4) + identity5 219 | 220 | return x5 221 | 222 | def _initialize_weights(self): 223 | init.kaiming_uniform_(self.conv_transpose1[0].weight, a=0.2, nonlinearity='leaky_relu') 224 | init.kaiming_uniform_(self.conv_transpose2[0].weight, a=0.2, nonlinearity='leaky_relu') 225 | init.kaiming_uniform_(self.conv_transpose3[0].weight, a=0.2, nonlinearity='leaky_relu') 226 | init.kaiming_uniform_(self.conv_transpose4[0].weight, a=0.2, nonlinearity='leaky_relu') 227 | init.xavier_uniform_(self.conv_transpose5[0].weight) 228 | self.conv_transpose5[0].bias.data.fill_(0) 229 | init.xavier_uniform_(self.fc.weight) 230 | self.fc.bias.data.fill_(0) 231 | 232 | 233 | class VAEModel(nn.Module): 234 | def __init__(self, encoder: VAEEncoder, decoder: VAEDecoder): 235 | super(VAEModel, self).__init__() 236 | 237 | self.encoder = encoder 238 | self.decoder = decoder 239 | 240 | def forward(self, x): 241 | # encode 242 | z_mu, z_var = self.encoder(x) 243 | 244 | # sample from the distribution having latent parameters z_mu, z_var 245 | # reparameterize 246 | std = torch.exp(z_var / 2) 247 | eps = torch.randn_like(std) 248 | x_sample = eps.mul(std) + z_mu 249 | 250 | # decode 251 | predicted = self.decoder(x_sample) 252 | return predicted, z_mu, z_var 253 | 254 | 255 | class VAEModelLoader(ModelLoader): 256 | def __init__(self, train_dataset, test_dataset, batch_size, model_path: str, if_early_stop=False, debug_mode=False): 257 | super().__init__(train_dataset, test_dataset, batch_size, model_path, if_early_stop, debug_mode) 258 | print('-' * 10, 'Loading VAE model', '-' * 10) 259 | # encoder 260 | self.latent_dim = latent_dim # latent vector dimension 261 | encoder = VAEEncoder(self.latent_dim) 262 | # decoder 263 | decoder = VAEDecoder(self.latent_dim) 264 | # VAE 265 | self.model = VAEModel(encoder, decoder).to(self.device) 266 | self.lr = 1e-4 # learning rate 267 | # optimizer 268 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) 269 | # 设置学习率调度 270 | # 余弦调度 271 | self.scheduler = CosineAnnealingLR(self.optimizer, T_max=50, eta_min=1e-5) 272 | 273 | # load exist model 274 | self.load_model() 275 | 276 | self.train_losses = [] 277 | self.test_losses = [] 278 | 279 | def _train_epoch(self): 280 | # set the train mode 281 | self.model.train() 282 | # loss of the epoch 283 | train_loss = 0 284 | for i, x in enumerate(self.train_iterator): 285 | # reshape the data into [batch_size, 3, 512, 512] 286 | x = x.view(-1, input_channel, 512, 512) # 后面需要conv,所以先调整size 287 | x = x.to(self.device) 288 | 289 | self.optimizer.zero_grad() 290 | 291 | 
x_sample, z_mu, z_var = self.model(x) 292 | 293 | # reconstruction loss 294 | # recon_loss = F.binary_cross_entropy(x_sample, x, reduction='sum') 295 | recon_loss = F.mse_loss(x_sample, x, reduction='sum') 296 | 297 | # kl divergence loss 298 | kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu ** 2 - 1.0 - z_var) 299 | 300 | # total loss 301 | loss = recon_loss + kl_loss 302 | 303 | loss.backward() 304 | train_loss += loss.item() 305 | 306 | self.optimizer.step() 307 | 308 | print(f'Train batch {i}, loss: {loss.item() / self.batch_size}') 309 | 310 | return train_loss / len(self.train_dataset) 311 | 312 | def _test_epoch(self): 313 | self.model.eval() 314 | test_loss = 0 315 | with torch.no_grad(): 316 | for i, x in enumerate(self.test_iterator): 317 | # reshape the data 318 | x = x.view(-1, input_channel, 512, 512) 319 | x = x.to(self.device) 320 | 321 | # forward pass 322 | x_sample, z_mu, z_var = self.model(x) 323 | 324 | # reconstruction loss 325 | # recon_loss = F.binary_cross_entropy(x_sample, x, reduction='sum') 326 | recon_loss = F.mse_loss(x_sample, x, reduction='sum') 327 | 328 | # kl divergence loss 329 | kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu ** 2 - 1.0 - z_var) 330 | 331 | # total loss 332 | loss = recon_loss + kl_loss 333 | test_loss += loss.item() 334 | 335 | print(f'Test batch {i}, loss: {loss.item() / self.batch_size}') 336 | 337 | return test_loss / len(self.test_dataset) 338 | 339 | def train(self, epochs=50, test_interval=10): 340 | if self.if_early_stop: 341 | # 早停策略防止过拟合 342 | best_test_loss = float('inf') 343 | patience_counter = 0 344 | 345 | for e in range(epochs): 346 | print('-' * 10, 'Train epoch', e, 'started!', '-' * 10) 347 | train_loss = self._train_epoch() 348 | test_loss = self._test_epoch() 349 | print(f'Epoch {e}, Train Loss: {train_loss:.2f}, Test Loss: {test_loss:.2f}') 350 | 351 | self.train_losses.append(train_loss) 352 | self.test_losses.append(test_loss) 353 | 354 | if self.scheduler: 355 | self.scheduler.step() 356 | # 保存模型 357 | self.save_model(test_loss) 358 | # 按照间隔测试模型 359 | if (e + 1) % test_interval == 0: 360 | # 获取test_imgs目录下的所有图片文件 361 | test_imgs_dir = "./test_imgs/" 362 | test_imgs_files = [f for f in os.listdir(test_imgs_dir) if f.endswith(".bmp")] 363 | 364 | # 对test_imgs目录下的图片进行重建测试 365 | for i, file in enumerate(test_imgs_files): 366 | file_path = os.path.join(test_imgs_dir, file) 367 | output_name = f"epoch_{e + 1}_test_img_{i + 1}" 368 | self.regenerate_test(Image.open(file_path), output_name) 369 | 370 | if self.if_early_stop: 371 | # 计算早停累计 372 | if best_test_loss > test_loss: 373 | best_test_loss = test_loss 374 | patience_counter = 1 375 | else: 376 | patience_counter += 1 377 | if patience_counter > max(epochs / 5, 10): 378 | # 早停 379 | print('Training interrupted to avoid overfitting.') 380 | break 381 | 382 | print(f'Final Train Loss: {self.train_losses[-1]:.2f}, Final Test Loss: {self.test_losses[-1]:.2f}') 383 | 384 | # 将 train_losses 和 test_losses 保存到文件中 385 | with open('losses.txt', 'w') as f: 386 | f.write('Train Losses:\n') 387 | f.write(', '.join(map(str, self.train_losses))) 388 | f.write('\n\nTest Losses:\n') 389 | f.write(', '.join(map(str, self.test_losses))) 390 | 391 | def random_generate_test(self, test_time=10, picture_name='test_image'): 392 | """ 393 | 根据特征向量分布随机生成特征向量并解码出图片 394 | :param test_time: 测试次数 395 | :param picture_name: 保存图片名称 396 | :return: None 397 | """ 398 | print('-' * 3, 'random_generate_test', '-' * 3) 399 | if not os.path.exists('./VAE_test/'): 400 | os.makedirs('./VAE_test/') 
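        # The loop below performs ancestral sampling from the generative model:
        # it draws z ~ N(0, I) in the latent space and runs z through the decoder
        # only, i.e. z = torch.randn(1, latent_dim) followed by decoder(z).
        # Note that the decoder ends in a Tanh block plus a learned skip branch,
        # so its output is not guaranteed to lie in [0, 1]; clamping, or rescaling
        # with (x + 1) / 2 if a [-1, 1] range is assumed, before the uint8
        # conversion below avoids wrap-around artifacts.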
401 | 402 | self.model.eval() 403 | with torch.no_grad(): 404 | for i in range(test_time): 405 | z = torch.randn(1, self.latent_dim).to(self.device) 406 | reconstructed_img = self.model.decoder(z) 407 | img = reconstructed_img.cpu().squeeze(0) # 从batch中移除,得到3x512x512的图片 408 | img = img.permute(1, 2, 0) # 调整为512x512x3 409 | 410 | # 将张量数据转换为PIL图像 411 | img = (img.numpy() * 255).astype(np.uint8) 412 | img = Image.fromarray(img) 413 | 414 | filename = f'./VAE_test/{picture_name}_{i}.png' 415 | img.save(filename) 416 | print('result saved to', filename) 417 | 418 | def regenerate_test(self, input_image: Image, file_name: str): 419 | """ 420 | 根据输入图片解码并重新生成测试模型效果 421 | :param input_image: 输入图片 422 | :param file_name: 输出图片名称 423 | :return: None 424 | """ 425 | print('-' * 3, 'regenerate_test', '-' * 3) 426 | if not os.path.exists('./VAE_test/'): 427 | os.makedirs('./VAE_test/') 428 | # 定义预处理转换链 429 | if input_channel == 3: 430 | transform = transforms.Compose([ 431 | transforms.Resize((512, 512)), # 将图像调整为512x512 432 | transforms.Lambda(convert_to_rgb), # 确保图像为三通道 433 | transforms.ToTensor() # 将图像转换为PyTorch张量 434 | ]) 435 | elif input_channel == 1: 436 | cv_image = np.array(input_image.convert('RGB')) 437 | # 转换为BGR格式 438 | cv_image = cv_image[:, :, ::-1] 439 | cv_image = img_pre_processing_gray(cv_image) # 返回二值化后的图 440 | input_image = Image.fromarray(cv_image) 441 | transform = transforms.Compose([ 442 | transforms.Resize((512, 512)), # 将图像调整为512x512 443 | transforms.ToTensor() # 将图像转换为PyTorch张量 444 | ]) 445 | # 应用预处理转换链 446 | input_tensor = transform(input_image) 447 | 448 | # 添加批次维度并将图像输入模型 449 | input_tensor = input_tensor.to(self.device).unsqueeze(0) # 添加批次维度,即从C x H x W变为1 x C x H x W 450 | 451 | with torch.no_grad(): 452 | # 编码图像,获取潜在空间的均值和方差对数 453 | z_mu, z_log_var = self.model.encoder(input_tensor) 454 | 455 | # 从标准正态分布中采样epsilon 456 | # std = torch.exp(z_log_var / 2) 457 | # eps = torch.randn_like(std) 458 | # z = z_mu + eps * std 459 | 460 | # 解码 461 | regenerated_image = self.model.decoder(z_mu) 462 | 463 | img = regenerated_image.cpu().squeeze(0) # 从batch中移除,得到inputchannel x 512 x 512的图片 464 | # print(img.shape) # torch.Size([1, 512, 512]) 465 | img = img.permute(1, 2, 0) # 调整为512x512x input_channel 466 | # print(img.shape) # torch.Size([512, 512, 1]) 467 | 468 | if img.shape[2] == 1: 469 | img = img.squeeze() # 去除单一通道维度 470 | img_np = img.numpy() 471 | if img_np.dtype == np.float32 or img_np.dtype == np.float64: 472 | img_np = (img_np * 255).astype(np.uint8) 473 | img = Image.fromarray(img_np, mode='L') # 'L' 模式代表灰度图 474 | elif img.shape[2] == 3: 475 | img_np = (img.numpy() * 255).astype(np.uint8) 476 | img = Image.fromarray(img_np) 477 | 478 | # 图片名称:添加参数picture_name和索引i 479 | filename = f'./VAE_test/{file_name}.png' 480 | img.save(filename) # 直接保存图片 481 | print('result saved to', filename) 482 | 483 | def test(self): 484 | # 获取test_imgs目录下的所有图片文件 485 | test_imgs_dir = "./test_imgs/" 486 | test_imgs_files = [f for f in os.listdir(test_imgs_dir) if f.endswith(".bmp")] 487 | 488 | # 对test_imgs目录下的图片进行重建测试 489 | for i, file in enumerate(test_imgs_files): 490 | file_path = os.path.join(test_imgs_dir, file) 491 | output_name = f"test_result_{i + 1}" 492 | self.regenerate_test(Image.open(file_path), output_name) 493 | -------------------------------------------------------------------------------- /CutTarget.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | import cv2 4 | import matplotlib 5 | import numpy 6 | import 
numpy as np 7 | 8 | from TargetDiscriminator import * 9 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator 10 | 11 | sys.path.append("..") 12 | 13 | matplotlib.use('TkAgg') 14 | crop_mode = True # 是否裁剪到最小范围 15 | input_dir = 'input' 16 | output_dir = 'output' 17 | image_files = [f for f in os.listdir(input_dir) if 18 | f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', 'bmp'))] 19 | files_num = len(image_files) 20 | 21 | sam_checkpoint = "./saved_model/sam_vit_l_0b3195.pth" 22 | sam_model_type = "vit_l" 23 | 24 | device = "cuda" 25 | 26 | sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint) 27 | sam.to(device=device) 28 | 29 | general_kernel1 = (100, 10) 30 | general_kernel2 = (150, 10) 31 | 32 | 33 | def pad_to_square(arr): 34 | h, w = arr.shape[:2] 35 | max_dim = max(h, w) 36 | 37 | top = (max_dim - h) // 2 38 | bottom = max_dim - h - top 39 | left = (max_dim - w) // 2 40 | right = max_dim - w - left 41 | 42 | padded_arr = np.pad(arr, ((top, bottom), (left, right), (0, 0)), mode='constant', constant_values=255) 43 | 44 | restore_info = { 45 | 'original_shape': (h, w), 46 | 'top': top, 47 | 'left': left 48 | } 49 | 50 | return padded_arr, restore_info 51 | 52 | 53 | def scale_image_to_fit_window(image, max_width=896, max_height=750): 54 | """ 55 | 将图片缩放以适应给定的最大宽度和高度。 56 | """ 57 | height, width = image.shape[:2] 58 | 59 | # 计算缩放比例 60 | scale_x = max_width / width 61 | scale_y = max_height / height 62 | scale = min(scale_x, scale_y) 63 | # 缩放图片 64 | scaled_image = cv2.resize(image, None, fx=scale, fy=scale) 65 | 66 | return scaled_image 67 | 68 | 69 | def apply_mask(image, mask, alpha_channel=True, kernel_size=()) -> tuple[np.ndarray, np.ndarray]: 70 | if (isinstance(kernel_size, tuple) or isinstance(kernel_size, list)) and len(kernel_size) == 2: 71 | # 将布尔类型的mask转换为uint8类型 72 | print('优化mask...') 73 | mask_processed = mask.astype(np.uint8) * 255 74 | 75 | # 应用腐蚀和膨胀操作 76 | kernel_1 = np.ones((kernel_size[0], kernel_size[0]), np.uint8) 77 | kernel_2 = np.ones((kernel_size[1], kernel_size[1]), np.uint8) 78 | 79 | mask_eroded = cv2.erode(mask_processed, kernel_1, iterations=1) 80 | 81 | # 对腐蚀后的图像进行连通域分析 82 | num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask_eroded, connectivity=8) 83 | 84 | # 设置连通域面积阈值 85 | area_threshold = 10000 86 | 87 | # 创建一个新的mask用于存储处理后的结果 88 | mask_processed = np.zeros_like(mask_eroded) 89 | 90 | # 只保留面积大于阈值的连通域 91 | for i in range(1, num_labels): 92 | if stats[i, cv2.CC_STAT_AREA] >= area_threshold: 93 | mask_processed[labels == i] = 255 94 | mask_dilated = cv2.dilate(mask_processed, kernel_1, iterations=1) 95 | mask_dilated = cv2.dilate(mask_dilated, kernel_2, iterations=1) 96 | mask_processed = cv2.erode(mask_dilated, kernel_2, iterations=1) 97 | 98 | # 显示处理前后的图像 99 | # scale_factor = 0.5 100 | # resized_original_mask = cv2.resize(mask.astype(np.uint8) * 255, None, fx=scale_factor, fy=scale_factor) 101 | # resized_processed_mask = cv2.resize(mask_processed, None, fx=scale_factor, fy=scale_factor) 102 | # cv2.imshow("Original Mask", resized_original_mask) 103 | # cv2.imshow("Processed Mask", resized_processed_mask) 104 | # cv2.waitKey(0) 105 | # cv2.destroyAllWindows() 106 | 107 | # 将处理后的mask转换回布尔类型 108 | mask = mask_processed.astype(bool) 109 | 110 | if alpha_channel: 111 | alpha = np.zeros_like(image[..., 0]) # 制作掩体 112 | alpha[mask] = 255 # 兴趣地方标记为1,且为白色 113 | image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha)) # 融合图像 114 | image = 
Image.fromarray(image).convert("RGBA") # 读取图像并转换为RGB模式 115 | 116 | # 如果图像有alpha通道,将RGBA图像转换为灰度图,并将透明部分填充为白色 117 | background = Image.new("RGBA", image.size, (255, 255, 255)) 118 | image = np.array(Image.alpha_composite(background, image)) 119 | else: 120 | image = np.where(mask[..., None], image, 0) 121 | 122 | # 显示原图和处理后的图像 123 | # scale_factor = 0.5 124 | # resized_original_image = cv2.resize(image, None, fx=scale_factor, fy=scale_factor) 125 | # resized_processed_image = cv2.resize(np.array(image), None, fx=scale_factor, fy=scale_factor) 126 | # cv2.imshow("Original Image", resized_original_image) 127 | # cv2.imshow("Processed Image", resized_processed_image) 128 | # cv2.waitKey(0) 129 | # cv2.destroyAllWindows() 130 | return image, mask 131 | 132 | 133 | def get_next_filename(base_path, filename): # 进行下一个图像 134 | name, ext = os.path.splitext(filename) 135 | for i in range(1, 101): 136 | new_name = f"{name}_{i}{ext}" 137 | if not os.path.exists(os.path.join(base_path, new_name)): 138 | return new_name 139 | return None 140 | 141 | 142 | def save_masked_image(image, mask, output_dir, filename, crop_mode_, kernel_size): # 保存掩盖部分的图像(感兴趣的图像) 143 | height, width = image.shape[:2] 144 | if crop_mode_: 145 | y, x = np.where(mask) 146 | y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max() 147 | cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1] 148 | cropped_image = image[y_min:y_max + 1, x_min:x_max + 1] 149 | masked_image, cropped_mask = apply_mask(cropped_image, cropped_mask, kernel_size=kernel_size) 150 | masked_image, info = pad_to_square(masked_image) 151 | print(masked_image.shape) 152 | else: 153 | masked_image, mask = apply_mask(image, mask, kernel_size=kernel_size) 154 | filename = filename[:filename.rfind('.')] + '.png' 155 | new_filename = get_next_filename(output_dir, filename) 156 | 157 | if new_filename: 158 | if masked_image.shape[-1] == 4: 159 | cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9]) 160 | else: 161 | cv2.imwrite(os.path.join(output_dir, new_filename), masked_image) 162 | print(f"Saved as {new_filename}") 163 | else: 164 | print("Could not save the image. 
Too many variations exist.") 165 | 166 | 167 | def show_anns(anns): 168 | if len(anns) == 0: 169 | return 170 | sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) 171 | return sorted_anns 172 | 173 | 174 | def cut_target_from_image(origin_image, reason_mode='cuda', area_lower_limit=30000, area_upper_limit=math.inf, 175 | pcb_prob=0.2): 176 | """ 177 | Input: 178 | origin image: PIL image 179 | reason mode: cpu or gpu 180 | area lower limit: the min area of target object, default = 30000 181 | area upper limit: the max area of target object, default = positive infinity 182 | pcb prob: the probability of cropped image being a real PCB, default = 0.2 183 | Output: 184 | masked_image: origin size 185 | coordinates: (x_min, x_max, y_min, y_max), the location of the selected PCB in the original image 186 | cropped mask: in order to eliminate the influence of the shadow in the reconstruction 187 | """ 188 | test_device = reason_mode 189 | sam.to(device=test_device) 190 | target_discriminator = TargetDiscriminator('./saved_model/Discriminator_trained.pth', device=reason_mode) 191 | image = numpy.array(origin_image) 192 | image_crop = image.copy() 193 | image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 194 | image = scale_image_to_fit_window(image_rgb) 195 | info = None 196 | 197 | mask_generator = SamAutomaticMaskGenerator(sam) 198 | masks = mask_generator.generate(image) 199 | 200 | for j in range(len(masks)): 201 | if area_lower_limit < masks[j]['area'] < area_upper_limit: 202 | target_mask = masks[j]['segmentation'] 203 | else: 204 | continue 205 | 206 | binary_image = np.uint8(target_mask) * 255 207 | resized_image = cv2.resize(binary_image, (image_crop.shape[1], image_crop.shape[0])) 208 | target_mask = resized_image > 0 209 | 210 | # masked_image = apply_mask(image_crop, target_mask) 211 | 212 | y, x = np.where(target_mask) 213 | y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max() 214 | cropped_mask = target_mask[y_min:y_max + 1, x_min:x_max + 1] 215 | cropped_image = image_crop[y_min:y_max + 1, x_min:x_max + 1] 216 | masked_image_cut, cropped_mask = apply_mask(cropped_image, cropped_mask) 217 | prob_pcb = target_discriminator.predict(Image.fromarray(masked_image_cut)) 218 | print(f'The probability of being a real PCB is: {prob_pcb:.4f}') 219 | if prob_pcb > pcb_prob: 220 | if y_max - y_min > x_max - x_min: 221 | diff = (y_max - y_min) - (x_max - x_min) 222 | if diff % 2 == 0: 223 | x_max += diff // 2 224 | x_min -= diff // 2 225 | else: 226 | x_max += (diff // 2) + 1 227 | x_min -= (diff // 2) - 1 228 | if x_max - x_min > y_max - y_min: 229 | diff = (x_max - x_min) - (y_max - y_min) 230 | if diff % 2 == 0: 231 | y_max += diff // 2 232 | y_min -= diff // 2 233 | else: 234 | y_max += (diff // 2) + 1 235 | y_min -= (diff // 2) - 1 236 | cropped_mask = target_mask[y_min:y_max + 1, x_min:x_max + 1] 237 | cropped_image = image_crop[y_min:y_max + 1, x_min:x_max + 1] 238 | masked_image_cut, cropped_mask = apply_mask(cropped_image, cropped_mask) 239 | break 240 | return Image.fromarray(masked_image_cut), [x_min, x_max, y_min, y_max], cropped_mask, info 241 | 242 | 243 | if __name__ == '__main__': 244 | for i in range(files_num): 245 | print("第{}张图:".format(i + 1)) 246 | filename = image_files[i] 247 | image = cv2.imread(os.path.join(input_dir, filename)) 248 | image_crop = image.copy() 249 | # image_crop = scale_image_to_fit_window(image).copy() 250 | image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 251 | image = scale_image_to_fit_window(image_rgb) 252 | # 
plt.figure(figsize=(20,20)) 253 | # plt.imshow(image) 254 | 255 | mask_generator = SamAutomaticMaskGenerator(sam) 256 | 257 | masks = mask_generator.generate(image) 258 | 259 | # plt.figure(figsize=(20, 20)) 260 | # plt.imshow(image) 261 | 262 | masks_list = show_anns(masks) 263 | 264 | # 遍历分析每个预测的掩码 265 | for _, mask in enumerate(masks): 266 | print(f"Mask {_}:") 267 | for key, value in mask.items(): 268 | print(f"{key}: {value}") 269 | print("---") 270 | 271 | target_mask = None 272 | MIN_AREA = 40000 273 | for j in range(len(masks_list)): 274 | if MIN_AREA < masks[j]['area']: 275 | target_mask = masks[j]['segmentation'] 276 | 277 | if target_mask is not None: 278 | binary_image = np.uint8(target_mask) * 255 279 | resized_image = cv2.resize(binary_image, (image_crop.shape[1], image_crop.shape[0])) 280 | target_mask = resized_image > 0 281 | 282 | # masked_image = apply_mask(image_crop, target_mask) 283 | 284 | y, x = np.where(target_mask) 285 | y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max() 286 | cropped_mask = target_mask[y_min:y_max + 1, x_min:x_max + 1] 287 | cropped_image = image_crop[y_min:y_max + 1, x_min:x_max + 1] 288 | masked_image_cut, cropped_mask = apply_mask(cropped_image, cropped_mask) 289 | 290 | save_masked_image(image_crop, target_mask, output_dir, filename, crop_mode_=crop_mode, kernel_size=0) 291 | else: 292 | print('未找到区域!') 293 | 294 | # plt.axis('off') 295 | # plt.show() 296 | -------------------------------------------------------------------------------- /DataAugment.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from math import * 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from PIL import Image, ImageEnhance 10 | from torchvision import transforms 11 | 12 | 13 | # image augment 14 | 15 | 16 | def rotate(img, angle): 17 | image = np.array(img.convert('RGB')) 18 | # 获取图像的高度和宽度 19 | height, width = image.shape[:2] 20 | 21 | # 计算旋转中心点 22 | center = (width // 2, height // 2) 23 | 24 | # 获取旋转矩阵 25 | rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0) 26 | 27 | # 对图像进行旋转,使用白色作为填充颜色 28 | rotated_image = cv2.warpAffine(image, rotation_matrix, (width, height), borderValue=(255, 255, 255)) 29 | 30 | # 将图像转换为灰度图 31 | gray = cv2.cvtColor(rotated_image, cv2.COLOR_BGR2GRAY) 32 | 33 | # 寻找图像的边界 34 | _, thresh = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY) 35 | contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 36 | 37 | # 计算图像的边界框 38 | x, y, w, h = cv2.boundingRect(contours[0]) 39 | 40 | # 裁剪图像 41 | cropped_image = rotated_image[y: y + h, x: x + w] 42 | 43 | return Image.fromarray(cropped_image) 44 | 45 | 46 | # tensor augment 47 | 48 | def augment_image_tensor(image_tensor, zoom_factors=[], target_size=(512, 512)): 49 | # 转换为PIL图像 50 | image = transforms.ToPILImage()(image_tensor) 51 | 52 | # 保存所有增强后的图像,先放入原始图像以及镜像后的图像 53 | augmented_images = [ 54 | transforms.Resize(target_size)(image_tensor), 55 | # transforms.Resize(target_size)(transforms.ToTensor()(transforms.functional.hflip(image))) 56 | ] 57 | 58 | rotations = [_ * 10 for _ in range(35)] 59 | 60 | # 对原始图像应用35次10度旋转/镜像 61 | for rotation in rotations: 62 | random_deviation = random.uniform(-3, 3) 63 | 64 | # 将随机偏差添加到旋转角度上 65 | adjusted_rotation = round(rotation + random_deviation) 66 | rotated_image = rotate(image, adjusted_rotation) 67 | rotated_image = rotated_image.resize(target_size, resample=Image.Resampling.BICUBIC) 68 
| 69 | rotated_tensor = transforms.ToTensor()(rotated_image) 70 | augmented_images.append(rotated_tensor) 71 | # augmented_images.append(transforms.ToTensor()(transforms.functional.hflip(rotated_image))) 72 | 73 | # 获取图像尺寸 74 | width, height = image.size 75 | # # 对不同倍率裁剪后的图像应用11次30度旋转和镜像 76 | for zoom_factor in zoom_factors: 77 | # 计算裁剪区域的坐标 78 | crop_size = int(min(width, height) / zoom_factor) 79 | left = (width - crop_size) // 2 80 | top = (height - crop_size) // 2 81 | right = left + crop_size 82 | bottom = top + crop_size 83 | 84 | # 裁剪图像并调整大小 85 | cropped_image = image.crop((left, top, right, bottom)) 86 | resized_image = cropped_image.resize(target_size, resample=Image.Resampling.BICUBIC) 87 | 88 | # 保存裁剪后的图像及其镜像 89 | resized_tensor = transforms.ToTensor()(resized_image) 90 | augmented_images.append(resized_tensor) 91 | # augmented_images.append(transforms.ToTensor()(transforms.functional.hflip(resized_image))) 92 | 93 | # 定义11次30度旋转的变换 94 | rotations = [_ * 30 for _ in range(11)] 95 | 96 | # 对裁剪后的图像应用11次30度旋转和镜像 97 | for rotation in rotations: 98 | # 生成一个-5到5之间的随机偏差 99 | random_deviation = random.uniform(-5, 5) 100 | 101 | # 将随机偏差添加到旋转角度上 102 | adjusted_rotation = rotation + random_deviation 103 | 104 | rotated_resized_image = rotate(resized_image, adjusted_rotation) 105 | rotated_resized_image = rotated_resized_image.resize(target_size, resample=Image.Resampling.BICUBIC) 106 | # display_image(rotated_resized_image) 107 | 108 | rotated_resized_tensor = transforms.ToTensor()(rotated_resized_image) 109 | augmented_images.append(rotated_resized_tensor) 110 | # augmented_images.append(transforms.ToTensor()(transforms.functional.hflip(rotated_resized_image))) 111 | 112 | # 随机调整亮度、对比度和饱和度 113 | for _ in range(3): 114 | brightness_factor = random.uniform(0.8, 1.2) 115 | contrast_factor = random.uniform(0.8, 1.2) 116 | saturation_factor = random.uniform(0.8, 1.2) 117 | 118 | enhanced_image = image.copy() 119 | enhancer = ImageEnhance.Brightness(enhanced_image) 120 | enhanced_image = enhancer.enhance(brightness_factor) 121 | enhancer = ImageEnhance.Contrast(enhanced_image) 122 | enhanced_image = enhancer.enhance(contrast_factor) 123 | enhancer = ImageEnhance.Color(enhanced_image) 124 | enhanced_image = enhancer.enhance(saturation_factor) 125 | 126 | enhanced_tensor = transforms.Resize(target_size)(transforms.ToTensor()(enhanced_image)) 127 | augmented_images.append(enhanced_tensor) 128 | 129 | # # 随机平移 130 | # for _ in range(2): 131 | # max_translation = 0.1 132 | # translate_x = int(random.uniform(-max_translation, max_translation) * width) 133 | # translate_y = int(random.uniform(-max_translation, max_translation) * height) 134 | # translated_image = transforms.functional.affine(image, 0, (translate_x, translate_y), 1.0, 0) 135 | # display_image(translated_image) 136 | # 137 | # translated_tensor = transforms.Resize(target_size)(transforms.ToTensor()(translated_image)) 138 | # augmented_images.append(translated_tensor) 139 | # 140 | # # 随机缩放 141 | # for _ in range(2): 142 | # min_scale = 0.9 143 | # max_scale = 1.1 144 | # scale_factor = random.uniform(min_scale, max_scale) 145 | # scaled_image = transforms.functional.affine(image, 0, (0, 0), scale_factor, 0) 146 | # 147 | # scaled_tensor = transforms.Resize(target_size)(transforms.ToTensor()(scaled_image)) 148 | # augmented_images.append(scaled_tensor) 149 | 150 | return augmented_images 151 | 152 | 153 | def do_nothing(image_tensor, target_size=(512, 512)): 154 | # 保存所有增强后的图像,先放入原始图像以及镜像后的图像 155 | augmented_images = [ 156 | 
transforms.Resize(target_size)(image_tensor), 157 | ] 158 | return augmented_images 159 | 160 | 161 | def limit_size(image_tensor): 162 | # 定义转换链 163 | transform = transforms.Compose([ 164 | transforms.ToPILImage(), # 将张量转换为PIL图像 165 | transforms.Resize((512, 512)), # 调整图像大小为512x512 166 | transforms.Grayscale(num_output_channels=1), # 转换为单通道灰度图像 167 | transforms.ToTensor() # 将PIL图像转换回张量 168 | ]) 169 | 170 | # 应用转换链 171 | processed_tensor = transform(image_tensor) 172 | 173 | return processed_tensor 174 | 175 | 176 | def add_stripe_noise_pattern_around_image(image_tensor, pattern_width_ratio=0.1, stripe_scale=0.05, 177 | stripe_intensity=0.2): 178 | C, H, W = image_tensor.shape 179 | device = image_tensor.device 180 | 181 | pattern_width = int(min(H, W) * pattern_width_ratio) 182 | 183 | # 创建一个条纹状噪声贴图 184 | noise_map = torch.zeros((H, W), device=device) 185 | for i in range(H): 186 | noise_map[i, :] = torch.sin(torch.tensor(i, device=device) * stripe_scale) * stripe_intensity 187 | 188 | # 对噪声贴图进行截取,只保留四周 189 | mask = torch.zeros((H, W), device=device) 190 | mask[:pattern_width, :] = 1 191 | mask[-pattern_width:, :] = 1 192 | mask[:, :pattern_width] = 1 193 | mask[:, -pattern_width:] = 1 194 | noise_map = noise_map * mask 195 | 196 | # 使用广播机制将噪声贴图应用于所有通道 197 | image_with_pattern = torch.clamp(image_tensor - noise_map, min=0) 198 | 199 | return image_with_pattern 200 | 201 | 202 | def add_gaussian_noise(image_tensor, noise_std=0.05): 203 | noise = torch.randn_like(image_tensor) * noise_std 204 | noisy_image = image_tensor + noise 205 | noisy_image = torch.clamp(noisy_image, 0, 1) 206 | return noisy_image 207 | 208 | 209 | def add_salt_pepper_noise(image_tensor, noise_density=0.05): 210 | num_pixels = int(noise_density * image_tensor.size(1) * image_tensor.size(2)) 211 | coords = torch.randint(0, image_tensor.size(1), (num_pixels, 2), device=image_tensor.device) 212 | values = torch.randint(0, 2, (num_pixels, image_tensor.size(0)), device=image_tensor.device).float() 213 | noisy_image = image_tensor.clone() 214 | noisy_image[:, coords[:, 0], coords[:, 1]] = values.transpose(0, 1) 215 | return noisy_image 216 | 217 | 218 | def grid_mask(image_tensor, num_blocks_range=(12, 24), block_size_range=(512 // 8, 512 // 5), 219 | fill_value_range=(0, 0.3)): 220 | """ 221 | 在图像上按规律排布随机数量的小方块 222 | :param image_tensor: 输入图像,PyTorch Tensor,shape 为 (C, H, W) 223 | :param num_blocks_range: 小方块数量范围,默认为 (12, 24) 224 | :param block_size_range: 小方块边长范围,默认为 (512//8, 512//5) 225 | :param fill_value_range: 掩码填充值范围,默认为 (0, 0.3) 226 | :return: 应用 Grid Mask 后的图像 227 | """ 228 | image_tensor = image_tensor.clone() 229 | _, h, w = image_tensor.shape 230 | 231 | num_blocks = torch.randint(num_blocks_range[0], num_blocks_range[1], (1,)).item() 232 | block_size = torch.randint(block_size_range[0], block_size_range[1], (1,)).item() 233 | 234 | num_rows = ceil(sqrt(num_blocks)) 235 | num_cols = ceil(num_blocks / num_rows) 236 | 237 | grid_size_h = h // num_rows 238 | grid_size_w = w // num_cols 239 | 240 | for i in range(num_rows): 241 | for j in range(num_cols): 242 | if torch.rand(1).item() < 0.5 and (i * num_cols + j) < num_blocks: 243 | start_h = i * grid_size_h + (grid_size_h - block_size) // 2 244 | start_w = j * grid_size_w + (grid_size_w - block_size) // 2 245 | end_h = start_h + block_size 246 | end_w = start_w + block_size 247 | 248 | # 随机选择填充值 249 | fill_value = torch.rand(1).item() * (fill_value_range[1] - fill_value_range[0]) + fill_value_range[0] 250 | image_tensor[..., start_h:end_h, start_w:end_w] = 
fill_value 251 | 252 | return image_tensor 253 | 254 | 255 | def add_gray_stripes(image_tensor, num_stripes_range=(1, 3), stripe_width=35, gray_range=(0.3, 0.55)): 256 | """ 257 | 随机向图像中添加灰色长条 258 | :param image_tensor: 输入图像, PyTorch Tensor, shape 为 (C, H, W) 259 | :param num_stripes_range: 随机生成的长条数量范围 260 | :param stripe_width: 长条的宽度(像素) 261 | :param gray_range: 灰度值的范围,表示为0到1之间的浮点数 262 | :return: 添加灰色长条后的图像 263 | """ 264 | C, H, W = image_tensor.shape 265 | 266 | # 随机生成长条的数量 267 | num_stripes = random.randint(num_stripes_range[0], num_stripes_range[1]) 268 | 269 | image_with_stripes = image_tensor.clone() 270 | 271 | for _ in range(num_stripes): 272 | # 随机决定长条的方向(0为横向,1为纵向,2为45度,3为30度,4为60度) 273 | direction = random.randint(0, 4) 274 | # 随机选择灰度值 275 | gray_value = random.uniform(gray_range[0], gray_range[1]) 276 | 277 | if direction == 0: # 横向长条 278 | y = random.randint(0, H - 1) 279 | image_with_stripes[:, max(0, y - stripe_width // 2):min(H, y + stripe_width // 2), :] = gray_value 280 | elif direction == 1: # 纵向长条 281 | x = random.randint(0, W - 1) 282 | image_with_stripes[:, :, max(0, x - stripe_width // 2):min(W, x + stripe_width // 2)] = gray_value 283 | elif direction == 2: # 45度长条 284 | for i in range(max(-W, -H), min(W, H)): 285 | if 0 <= i < W and 0 <= H - i - 1 < H: 286 | image_with_stripes[:, max(0, H - i - 1 - stripe_width // 2):min(H, H - i - 1 + stripe_width // 2), 287 | max(0, i - stripe_width // 2):min(W, i + stripe_width // 2)] = gray_value 288 | elif direction == 3: # 30度长条 289 | for i in range(max(-W, -int(H / math.sqrt(3))), min(W, int(H / math.sqrt(3)))): 290 | if 0 <= i < W and 0 <= int(H - i * math.sqrt(3)) - 1 < H: 291 | image_with_stripes[:, max(0, int(H - i * math.sqrt(3)) - 1 - stripe_width // 2):min(H, 292 | int(H - i * math.sqrt( 293 | 3)) - 1 + stripe_width // 2), 294 | max(0, i - stripe_width // 2):min(W, i + stripe_width // 2)] = gray_value 295 | else: # 60度长条 296 | for i in range(max(-W, -int(H * math.sqrt(3))), min(W, int(H * math.sqrt(3)))): 297 | if 0 <= i < W and 0 <= int(H - i / math.sqrt(3)) - 1 < H: 298 | image_with_stripes[:, max(0, int(H - i / math.sqrt(3)) - 1 - stripe_width // 2):min(H, 299 | int(H - i / math.sqrt( 300 | 3)) - 1 + stripe_width // 2), 301 | max(0, i - stripe_width // 2):min(W, i + stripe_width // 2)] = gray_value 302 | 303 | return image_with_stripes 304 | 305 | 306 | def add_central_distortion(image_tensor, distortion_strength_range=(0.2, 0.3), radius_range=(0.35, 0.55)): 307 | """ 308 | 在图像中心添加径向扭曲 309 | :param image_tensor: 输入图像, PyTorch Tensor, shape 为 (C, H, W) 310 | :param distortion_strength_range: 扭曲强度的随机范围 311 | :param radius_range: 影响半径的随机范围,以图像宽高的最小值的比例给出 312 | :return: 添加扭曲后的图像 313 | """ 314 | # 确保图像在正确的设备上 315 | device = image_tensor.device 316 | 317 | C, H, W = image_tensor.shape 318 | center_x, center_y = W / 2, H / 2 319 | 320 | # 随机生成扭曲强度和影响半径 321 | distortion_strength = torch.rand(1).item() * (distortion_strength_range[1] - distortion_strength_range[0]) + \ 322 | distortion_strength_range[0] 323 | radius = torch.rand(1).item() * (radius_range[1] - radius_range[0]) + radius_range[0] 324 | radius = min(W, H) * radius 325 | 326 | # 生成像素坐标网格 327 | xv, yv = torch.meshgrid(torch.arange(W, device=device), torch.arange(H, device=device), indexing="ij") 328 | 329 | # 计算每个像素距离中心的距离 330 | dist = torch.sqrt((xv - center_x) ** 2 + (yv - center_y) ** 2) 331 | 332 | # 应用径向扭曲模型 333 | factor = torch.ones_like(dist) 334 | within_radius = dist < radius 335 | factor[within_radius] = 1 + distortion_strength * torch.sin(torch.pi * 
dist[within_radius] / radius) 336 | 337 | # 计算扭曲后的坐标 338 | xv_new = (xv - center_x) * factor + center_x 339 | yv_new = (yv - center_y) * factor + center_y 340 | 341 | # 创建坐标网格以进行采样 342 | grid = torch.stack([yv_new / (H - 1) * 2 - 1, xv_new / (W - 1) * 2 - 1], dim=2) 343 | 344 | # 重采样图像 345 | distorted_image = F.grid_sample(image_tensor.unsqueeze(0), grid.unsqueeze(0), mode='bilinear', 346 | padding_mode='border', 347 | align_corners=True).squeeze(0) 348 | 349 | return distorted_image 350 | 351 | 352 | def random_kernel_filter(image_tensor, kernel_size_range=(5, 11), sigma_range=(0.5, 0.9)): 353 | """ 354 | 对图像应用随机卷积核滤波 355 | :param image_tensor: 输入图像, PyTorch Tensor, shape 为 (C, H, W) 356 | :param kernel_size_range: 卷积核大小的随机范围,必须为奇数 357 | :param sigma_range: 高斯卷积核的标准差随机范围 358 | :return: 应用卷积核滤波后的图像 359 | """ 360 | # 确保图像在正确的设备上 361 | device = image_tensor.device 362 | 363 | # 随机生成卷积核大小和标准差 364 | kernel_size = torch.randint(kernel_size_range[0] // 2, kernel_size_range[1] // 2 + 1, (1,)).item() * 2 + 1 365 | sigma = torch.rand(1).item() * (sigma_range[1] - sigma_range[0]) + sigma_range[0] 366 | 367 | # 生成高斯卷积核 368 | kernel = gaussian_kernel(kernel_size, sigma).to(device) 369 | 370 | # 对图像应用卷积 371 | filtered_image = F.conv2d(image_tensor.unsqueeze(0), kernel.unsqueeze(0).unsqueeze(0), 372 | padding=kernel_size // 2).squeeze(0) 373 | 374 | return filtered_image 375 | 376 | 377 | def gaussian_kernel(kernel_size, sigma): 378 | """ 379 | 生成高斯卷积核 380 | :param kernel_size: 卷积核大小,必须为奇数 381 | :param sigma: 高斯分布的标准差 382 | :return: 高斯卷积核, PyTorch Tensor, shape 为 (kernel_size, kernel_size) 383 | """ 384 | assert kernel_size % 2 == 1, "Kernel size must be odd" 385 | 386 | # 生成坐标网格 387 | x = torch.arange(-(kernel_size // 2), (kernel_size // 2) + 1, dtype=torch.float32) 388 | y = torch.arange(-(kernel_size // 2), (kernel_size // 2) + 1, dtype=torch.float32) 389 | xy_grid = torch.stack(torch.meshgrid(x, y), dim=-1) 390 | 391 | # 计算高斯分布 392 | kernel = torch.exp(-torch.sum(xy_grid ** 2, dim=-1) / (2 * sigma ** 2)) 393 | kernel = kernel / kernel.sum() 394 | 395 | return kernel 396 | 397 | 398 | def rotate_image(image_tensor, angle): 399 | """ 400 | 旋转图像 401 | :param image_tensor: 输入图像,PyTorch Tensor,shape 为 (C, H, W) 402 | :param angle: 旋转角度,可选值为 0, 90, 180, 270 403 | :return: 旋转后的图像 404 | """ 405 | if angle == 0: 406 | return image_tensor 407 | elif angle == 90: 408 | return torch.rot90(image_tensor, k=1, dims=[-2, -1]) 409 | elif angle == 180: 410 | return torch.rot90(image_tensor, k=2, dims=[-2, -1]) 411 | elif angle == 270: 412 | return torch.rot90(image_tensor, k=3, dims=[-2, -1]) 413 | else: 414 | raise ValueError("Invalid rotation angle. Supported angles are 0, 90, 180, and 270.") 415 | -------------------------------------------------------------------------------- /Datasets.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class TensorDataset(Dataset): 5 | def __init__(self, images, transform=None): 6 | """ 7 | Args: 8 | images (list): List of PIL images. 9 | transform (callable, optional): Optional transform to be applied on a sample. 
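
        Example (a minimal usage sketch; assumes ``images`` is a list of
        same-shaped image tensors, e.g. ``[C, 512, 512]``, so the default
        DataLoader collate can stack them into a batch):
            >>> from torch.utils.data import DataLoader
            >>> dataset = TensorDataset(images)
            >>> loader = DataLoader(dataset, batch_size=4, shuffle=True)
            >>> batch = next(iter(loader))  # tensor of shape [4, C, 512, 512]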
10 | """ 11 | self.images = images 12 | self.transform = transform 13 | 14 | def __len__(self): 15 | return len(self.images) 16 | 17 | def __getitem__(self, idx): 18 | image = self.images[idx] 19 | return image 20 | -------------------------------------------------------------------------------- /GAN.py: -------------------------------------------------------------------------------- 1 | from torchvision.utils import save_image 2 | 3 | from AutoEncoder import * 4 | from DataAugment import * 5 | from ImagePreProcessing import img_pre_processing_gray 6 | from ModelLoader import ModelLoader 7 | 8 | adv_weight = 900 9 | recon_weight = 0.4 10 | kl_weight = 1.0 11 | regen_weight = 0.6 12 | 13 | 14 | class ResidualBlock(nn.Module): 15 | def __init__(self, in_features): 16 | super(ResidualBlock, self).__init__() 17 | 18 | self.block = nn.Sequential( 19 | nn.ReflectionPad2d(1), 20 | nn.Conv2d(in_features, in_features, 3), 21 | nn.InstanceNorm2d(in_features), 22 | nn.LeakyReLU(0.2, inplace=True), 23 | nn.ReflectionPad2d(1), 24 | nn.Conv2d(in_features, in_features, 3), 25 | nn.InstanceNorm2d(in_features) 26 | ) 27 | 28 | def forward(self, x): 29 | return x + self.block(x) 30 | 31 | 32 | class Discriminator(nn.Module): 33 | def __init__(self, input_nc=input_channel, ndf=64): 34 | super(Discriminator, self).__init__() 35 | 36 | model = [ 37 | nn.ReflectionPad2d(1), 38 | nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0), 39 | nn.LeakyReLU(0.2, True), 40 | ResidualBlock(ndf), 41 | 42 | nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1), 43 | nn.InstanceNorm2d(ndf * 2), 44 | nn.LeakyReLU(0.2, True), 45 | ResidualBlock(ndf * 2), 46 | 47 | nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1), 48 | nn.InstanceNorm2d(ndf * 4), 49 | nn.LeakyReLU(0.2, True), 50 | ResidualBlock(ndf * 4), 51 | 52 | nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1), 53 | nn.InstanceNorm2d(ndf * 8), 54 | nn.LeakyReLU(0.2, True), 55 | ResidualBlock(ndf * 8), 56 | 57 | nn.Conv2d(ndf * 8, ndf * 8, kernel_size=4, stride=2, padding=1), 58 | nn.InstanceNorm2d(ndf * 8), 59 | nn.LeakyReLU(0.2, True), 60 | ResidualBlock(ndf * 8), 61 | 62 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=1), 63 | nn.Sigmoid(), 64 | nn.AdaptiveAvgPool2d(1) # 新增一个自适应平均池化层,压缩为1x1 65 | ] 66 | 67 | self.model = nn.Sequential(*model) 68 | 69 | def forward(self, input): 70 | return self.model(input).view(-1) # 将输出拉平为一维向量 71 | 72 | 73 | def discriminator_loss(real_out, fake_out): 74 | d_loss = -1 * (torch.log(real_out) + torch.log(1 - fake_out)) 75 | return d_loss.mean() 76 | 77 | 78 | class VAEGANModelLoader(ModelLoader): 79 | def __init__(self, train_dataset, test_dataset, batch_size, model_path: str, discriminator_path: str, if_early_stop=False, 80 | debug_mode=False): 81 | super().__init__(train_dataset, test_dataset, batch_size, model_path, if_early_stop, debug_mode) 82 | print('-' * 10, 'Loading VAE-GAN model', '-' * 10) 83 | # encoder 84 | self.latent_dim = latent_dim # latent vector dimension 85 | encoder = VAEEncoder(self.latent_dim) 86 | # decoder 87 | decoder = VAEDecoder(self.latent_dim) 88 | # VAE 89 | self.model = VAEModel(encoder, decoder).to(self.device) 90 | self.lr = 1e-4 # learning rate 91 | # optimizer 92 | self.G_optimizer = optim.Adam(self.model.parameters(), lr=self.lr) 93 | self.G_scheduler = CosineAnnealingLR(self.G_optimizer, T_max=20, eta_min=1e-5) 94 | 95 | # discriminator 96 | self.discriminator = Discriminator(input_channel).to(self.device) 97 | 98 | self.D_lr = 1e-8 99 | 
self.D_optimizer = optim.Adam(self.discriminator.parameters(), lr=self.D_lr) 100 | self.D_scheduler = CosineAnnealingLR(self.D_optimizer, T_max=20, eta_min=0) 101 | 102 | self.criterion = nn.BCELoss() 103 | self.criterion_GAN = nn.MSELoss() 104 | self.criterion_cycle = nn.L1Loss() 105 | self.criterion_identity = nn.L1Loss() 106 | 107 | # self.D_scheduler = CosineAnnealingLR(self.D_optimizer, T_max=50, eta_min=1e-5) 108 | 109 | self.discriminator_path = discriminator_path 110 | 111 | self.real_labels = torch.ones(batch_size).to(self.device) 112 | self.fake_labels = torch.zeros(batch_size).to(self.device) 113 | 114 | # load exist model 115 | self.load_model() 116 | 117 | def load_model(self): 118 | # load model weight 119 | model_dir = os.path.dirname(self.model_path) 120 | print('Try to load model from', self.model_path) 121 | # 检查模型文件夹路径是否存在 122 | if not os.path.exists(model_dir): 123 | # 不存在就创建新的目录 124 | os.makedirs(model_dir) 125 | print(f"Created directory '{model_dir}' for saving models.") 126 | if os.path.isfile(self.model_path): 127 | try: 128 | self.model.load_state_dict(torch.load(self.model_path, map_location=self.device)) 129 | print("VAE model loaded successfully from '{}'".format(self.model_path)) 130 | except Exception as e: 131 | print("Failed to load VAE model. Starting from scratch. Error: ", e) 132 | else: 133 | print("No saved model found at '{}'. Starting from scratch.".format(self.model_path)) 134 | if os.path.isfile(self.discriminator_path): 135 | try: 136 | self.discriminator.load_state_dict(torch.load(self.discriminator_path, map_location=self.device)) 137 | print("Discriminator A model loaded successfully from '{}'".format(self.discriminator_path)) 138 | except Exception as e: 139 | print("Failed to load discriminator A model. Starting from scratch. Error: ", e) 140 | else: 141 | print("No saved model found at '{}'. 
Starting from scratch.".format(self.discriminator_path)) 142 | 143 | def _test_epoch_vae(self): 144 | self.model.eval() 145 | test_loss = 0 146 | with torch.no_grad(): 147 | for i, x in enumerate(self.test_iterator): 148 | # reshape the data 149 | x = x.view(-1, input_channel, 512, 512) 150 | x = x.to(self.device) 151 | 152 | # forward pass 153 | x_sample, z_mu, z_var = self.model(x) 154 | 155 | recon_loss = F.mse_loss(x_sample, x, reduction='sum') 156 | kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu ** 2 - 1.0 - z_var) 157 | 158 | g_loss = recon_loss * recon_weight + kl_loss * kl_weight 159 | 160 | test_loss += g_loss.item() 161 | 162 | print(f'VAE test batch {i}, g_loss: {g_loss.item() / self.batch_size}') 163 | 164 | return test_loss / len(self.test_dataset) 165 | 166 | def save_dis_model(self): 167 | """ 168 | 保存判别器的权重。 169 | """ 170 | torch.save(self.discriminator.state_dict(), self.discriminator_path) 171 | print(f'Model saved to {self.discriminator_path}') 172 | 173 | def train(self, epochs=100, test_interval=10): 174 | train_losses_G = [] 175 | train_losses_D = [] 176 | test_losses_G = [] 177 | 178 | for epoch in range(epochs): 179 | torch.cuda.empty_cache() 180 | train_loss_G = 0 181 | train_loss_D = 0 182 | 183 | for batch, x in enumerate(self.train_iterator): 184 | # 转换图像张量的形状并移至指定的设备 185 | x = x.view(-1, input_channel, 512, 512) 186 | x = x.to(self.device) 187 | 188 | # 对每张图片应用随机旋转 189 | rotation_angles = torch.randint(0, 4, (x.size(0),)) * 90 190 | for i in range(x.size(0)): 191 | x[i] = rotate_image(x[i], rotation_angles[i].item()) 192 | 193 | # 复制x到x_,以便在x_上添加噪声,同时保持x不变 194 | x_ = x.clone() 195 | 196 | for idx in range(x_.size(0)): 197 | image = x_[idx] # 获取单张图片 198 | 199 | # image_ = image.clone() 200 | 201 | random_judgement = random.random() 202 | if random_judgement <= 0.25: 203 | # 增加随机数量、随机大小的黑色小方块 204 | # print('加入小方块') 205 | image = grid_mask(image) 206 | elif random_judgement <= 0.5: 207 | # print('加入长条') 208 | image = add_gray_stripes(image) 209 | else: 210 | # print('无修改') 211 | pass 212 | 213 | random_judgement = random.random() 214 | if random_judgement <= 0.15: 215 | # 加椒盐噪声 216 | # print('椒盐噪声') 217 | image = add_salt_pepper_noise(image) 218 | elif random_judgement <= 0.3: 219 | # 加高斯噪声 220 | # print('高斯噪声') 221 | image = add_gaussian_noise(image) 222 | elif random_judgement <= 0.5: 223 | # 在图像中心附近做一点扭曲 224 | # print('加入扭曲') 225 | image = add_central_distortion(image) 226 | elif random_judgement <= 0.7: 227 | # 随机向图像中添加一块模糊遮罩 228 | # print('加入模糊') 229 | image = random_kernel_filter(image) 230 | elif random_judgement <= 0.9: 231 | image = add_stripe_noise_pattern_around_image(image) 232 | else: 233 | # print('无噪声') 234 | pass 235 | 236 | x_[idx] = image 237 | 238 | # 显示原始图片和带噪声的图片 239 | # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) 240 | # ax1.imshow(image_.squeeze().cpu().numpy(), cmap="gray") 241 | # ax1.set_title("Original Image") 242 | # ax1.axis("off") 243 | # ax2.imshow(image.squeeze().cpu().numpy(), cmap="gray") 244 | # ax2.set_title("Noisy Image") 245 | # ax2.axis("off") 246 | # plt.tight_layout() 247 | # plt.show() 248 | 249 | # 生成器训练 250 | self.G_optimizer.zero_grad() 251 | 252 | # 生成两张不同风格的图像y和z 253 | y, z_mu_y, z_var_y = self.model(x_) 254 | z, z_mu_z, z_var_z = self.model(y) 255 | 256 | # 计算重建损失和KL散度损失 257 | recon_loss_y = F.mse_loss(y, x, reduction='sum') 258 | kl_loss_y = 0.5 * torch.sum(torch.exp(z_var_y) + z_mu_y ** 2 - 1.0 - z_var_y) 259 | # 再次重建损失 260 | regen_loss = F.mse_loss(x, z, reduction='sum') 261 | kl_loss_z = 0.5 * 
torch.sum(torch.exp(z_var_z) + z_mu_z ** 2 - 1.0 - z_var_z) 262 | 263 | # 对生成的图像进行判别 264 | loss_G_adv = self.criterion_GAN(self.discriminator(y), self.real_labels) + \ 265 | self.criterion_GAN(self.discriminator(z), self.real_labels) 266 | 267 | # 计算生成器的总损失 268 | loss_G = recon_loss_y * recon_weight + ( 269 | kl_loss_y + kl_loss_z) * kl_weight + loss_G_adv * adv_weight + regen_loss * regen_weight 270 | loss_G.backward() 271 | self.G_optimizer.step() 272 | 273 | # 判别器训练 274 | self.D_optimizer.zero_grad() 275 | real_loss = self.criterion_GAN(self.discriminator(x), self.real_labels) 276 | fake_loss_y = self.criterion_GAN(self.discriminator(y.detach()), self.fake_labels) 277 | fake_loss_z = self.criterion_GAN(self.discriminator(z.detach()), self.fake_labels) 278 | loss_D = (real_loss + fake_loss_y + fake_loss_z) / 3 279 | loss_D.backward() 280 | self.D_optimizer.step() 281 | 282 | train_loss_G += loss_G.item() 283 | train_loss_D += loss_D.item() 284 | 285 | # 打印损失 286 | print(f"Epoch [{epoch + 1}/{epochs}], Batch [{batch + 1}/{len(self.train_iterator)}], " 287 | f"D Loss: {loss_D.item():.4f}, G Loss: {loss_G.item() / self.batch_size:.4f}") 288 | 289 | self.G_scheduler.step() 290 | self.D_scheduler.step() 291 | 292 | # 计算平均训练损失 293 | train_loss_G /= len(self.train_dataset) 294 | train_loss_D /= len(self.train_dataset) 295 | 296 | train_losses_G.append(train_loss_G) 297 | train_losses_D.append(train_loss_D) 298 | 299 | test_loss_g = self._test_epoch_vae() 300 | test_losses_G.append(test_loss_g) 301 | 302 | print( 303 | f'Epoch {epoch + 1}, Train G Loss: {train_loss_G:.6f}, Train D Loss: {train_loss_D:.6f}, ' 304 | f'Test G Loss: {test_loss_g}') 305 | 306 | # 保存模型 307 | self.save_model() 308 | self.save_dis_model() 309 | 310 | # 按照间隔测试模型 311 | if (epoch + 1) % test_interval == 0: 312 | # 获取test_imgs目录下的所有图片文件 313 | test_imgs_dir = "./test_imgs/" 314 | test_imgs_files = [f for f in os.listdir(test_imgs_dir) if f.endswith(".bmp") or f.endswith(".png")] 315 | 316 | # 对test_imgs目录下的图片进行重建测试 317 | for _, file in enumerate(test_imgs_files): 318 | file_path = os.path.join(test_imgs_dir, file) 319 | output_name = f"epoch_{epoch + 1}_test_img_{_ + 1}" 320 | self.regenerate_test(Image.open(file_path), output_name) 321 | 322 | # 将误差写入txt文件 323 | with open('cycle_gan_losses.txt', 'w') as f: 324 | f.write('Epoch\tTrain_G_Loss\tTrain_D_Loss\tTest_G_Loss\n') 325 | for batch in range(epochs): 326 | f.write( 327 | f'{batch}\t{train_losses_G[batch]:.6f}\t{train_losses_D[batch]:.6f}\t{test_losses_G[batch]}\n') 328 | 329 | def regenerate_test(self, input_image: Image, file_name: str): 330 | """ 331 | 根据输入图片解码并重新生成测试模型效果 332 | :param input_image: 输入图片 333 | :param file_name: 输出图片名称 334 | :return: None 335 | """ 336 | print('-' * 3, 'regenerate_test', '-' * 3) 337 | if not os.path.exists('./train_result/'): 338 | os.makedirs('./train_result/') 339 | # 定义预处理转换链 340 | if input_channel == 3: 341 | transform = transforms.Compose([ 342 | transforms.Resize((512, 512)), # 将图像调整为512x512 343 | transforms.Lambda(convert_to_rgb), # 确保图像为三通道 344 | transforms.ToTensor() # 将图像转换为PyTorch张量 345 | ]) 346 | elif input_channel == 1: 347 | cv_image = np.array(input_image.convert('RGB')) 348 | # 转换为BGR格式 349 | cv_image = cv_image[:, :, ::-1] 350 | cv_image = img_pre_processing_gray(cv_image) # 返回二值化后的图 351 | input_image = Image.fromarray(cv_image) 352 | transform = transforms.Compose([ 353 | transforms.Resize((512, 512)), # 将图像调整为512x512 354 | transforms.ToTensor() # 将图像转换为PyTorch张量 355 | ]) 356 | # 应用预处理转换链 357 | input_tensor = 
transform(input_image) 358 | 359 | # 添加批次维度并将图像输入模型 360 | input_tensor = input_tensor.to(self.device).unsqueeze(0) # 添加批次维度,即从C x H x W变为1 x C x H x W 361 | 362 | with torch.no_grad(): 363 | # 编码图像,获取潜在空间的均值和方差对数 364 | z_mu, z_log_var = self.model.encoder(input_tensor) 365 | 366 | # 从标准正态分布中采样epsilon 367 | # std = torch.exp(z_log_var / 2) 368 | # eps = torch.randn_like(std) 369 | # z = z_mu + eps * std 370 | 371 | # 解码 372 | regenerated_image = self.model.decoder(z_mu) 373 | 374 | regenerated_image = regenerated_image.cpu() 375 | # 图片名称:添加参数picture_name和索引i 376 | filename = f'./train_result/{file_name}.png' 377 | save_image(regenerated_image, filename, normalize=True) 378 | print('result saved to', filename) 379 | 380 | def test(self): 381 | # 获取test_imgs目录下的所有图片文件 382 | test_imgs_dir = "./test_imgs/" 383 | test_imgs_files = [f for f in os.listdir(test_imgs_dir) if f.endswith(".bmp") or f.endswith(".png")] 384 | 385 | # 对test_imgs目录下的图片进行重建测试 386 | for i, file in enumerate(test_imgs_files): 387 | file_path = os.path.join(test_imgs_dir, file) 388 | output_name = f"test_result_{i + 1}" 389 | self.regenerate_test(Image.open(file_path), output_name) 390 | -------------------------------------------------------------------------------- /GenerateTestDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from ImageTools import * 6 | 7 | 8 | def load_test_dataset(test_imgs_dir='./test_images/'): 9 | # 检查测试图片目录是否存在 10 | if not os.path.exists(test_imgs_dir): 11 | print(f"Test images directory not found: {test_imgs_dir}") 12 | return None 13 | 14 | # 获取目录下所有图片文件的路径 15 | img_paths = [os.path.join(test_imgs_dir, f) for f in os.listdir(test_imgs_dir) if 16 | f.endswith(('.jpg', '.jpeg', '.png', '.bmp'))] 17 | 18 | # 定义图片预处理转换 19 | transform = transforms.Compose([ 20 | transforms.Resize((512, 512)), # 调整图片大小为512x512 21 | transforms.ToTensor(), # 将图片转换为张量 22 | ]) 23 | 24 | # 加载图片并应用预处理转换 25 | test_dataset = [] 26 | for img_path in img_paths: 27 | image = Image.open(img_path).convert("RGBA") # 读取图像并转换为RGBA模式 28 | 29 | # 将RGBA图像转换为灰度图,并将透明部分填充为白色 30 | background = Image.new("RGBA", image.size, (255, 255, 255)) 31 | alpha_composite = Image.alpha_composite(background, image) 32 | gray_image = alpha_composite.convert("L") 33 | image_tensor = transform(gray_image) # 将灰度图像转换为张量 34 | test_dataset.append(image_tensor) 35 | 36 | # 将图片数据转换为张量 37 | test_dataset = torch.stack(test_dataset) 38 | print(f'Test datas have been loaded from {test_imgs_dir}') 39 | return test_dataset 40 | 41 | 42 | if __name__ == '__main__': 43 | # 调用函数加载测试图片数据集 44 | test_dataset = load_test_dataset() 45 | 46 | if test_dataset is not None: 47 | print(f"Test dataset loaded successfully. 
Shape: {test_dataset.shape}") 48 | # 可以将test_dataset保存到文件中,类似于提供的代码 49 | test_dataset_path = './datasets/test_dataset.pt' 50 | torch.save(test_dataset, test_dataset_path) 51 | print(f"Test dataset has been saved to {test_dataset_path}.") 52 | else: 53 | print("Failed to load test dataset.") 54 | -------------------------------------------------------------------------------- /ImagePreProcessing.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def read_img(image_path: str) -> np.ndarray: 6 | image = cv2.imread(image_path) 7 | if image is None: 8 | print("Image not found or invalid image format") 9 | return None 10 | return image 11 | 12 | 13 | def show_img(images, win_name=None): 14 | if not isinstance(images, list): 15 | images = [images] 16 | 17 | for i, image in enumerate(images): 18 | # 调整图像大小以适应屏幕 19 | screen_res = 1280, 720 # 示例屏幕分辨率 20 | scale_width = screen_res[0] / image.shape[1] 21 | scale_height = screen_res[1] / image.shape[0] 22 | scale = min(scale_width, scale_height) 23 | 24 | window_width = int(image.shape[1] * scale) 25 | window_height = int(image.shape[0] * scale) 26 | 27 | # 为每个图像创建一个唯一的窗口 28 | if win_name is None: 29 | window_name = f'image {i}' 30 | else: 31 | window_name = win_name + f' {i}' 32 | cv2.namedWindow(window_name, cv2.WINDOW_NORMAL) 33 | cv2.resizeWindow(window_name, window_width, window_height) 34 | cv2.imshow(window_name, image) 35 | 36 | cv2.waitKey(0) 37 | cv2.destroyAllWindows() 38 | 39 | 40 | def img_pre_processing_gray(image): 41 | gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 42 | return gray_image 43 | 44 | 45 | def img_pre_processing_binary(image): 46 | # 分离 BGR 通道 47 | b_channel, g_channel, r_channel = cv2.split(image) 48 | 49 | # 对每个通道分别应用中值模糊和自适应阈值二值化 50 | binary_channels = [] 51 | for channel in [b_channel, g_channel, r_channel]: 52 | blurred_channel = cv2.medianBlur(channel, 9) 53 | binary_channel = cv2.adaptiveThreshold(blurred_channel, 255, 54 | cv2.ADAPTIVE_THRESH_GAUSSIAN_C, 55 | cv2.THRESH_BINARY, 17, 6) 56 | binary_channels.append(binary_channel) 57 | 58 | # 合并二值化后的通道 59 | binary_image = cv2.merge(binary_channels) 60 | binary_image = cv2.cvtColor(binary_image, cv2.COLOR_BGR2GRAY) 61 | 62 | _, binary_image = cv2.threshold(binary_image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) 63 | kernel = np.ones((5, 5), np.uint8) 64 | binary_image = cv2.erode(binary_image, kernel, iterations=1) 65 | binary_image = cv2.dilate(binary_image, kernel, iterations=1) 66 | 67 | # 应用高斯模糊到合并后的二值化图像 68 | binary_image = cv2.GaussianBlur(binary_image, (5, 5), 0) 69 | return binary_image 70 | -------------------------------------------------------------------------------- /ImageTools.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from matplotlib import pyplot as plt 3 | from torchvision import transforms 4 | 5 | 6 | def load_image_to_tensor(image_path: str): 7 | image = Image.open(image_path).convert("RGBA") # 读取图像并转换为RGB模式 8 | 9 | # 如果图像有alpha通道,将RGBA图像转换为灰度图,并将透明部分填充为白色 10 | background = Image.new("RGBA", image.size, (255, 255, 255)) 11 | alpha_composite = Image.alpha_composite(background, image) 12 | gray_image = alpha_composite.convert("L") 13 | 14 | # 显示灰度图像 15 | # plt.figure(figsize=(8, 8)) 16 | # plt.imshow(gray_image, cmap='gray') 17 | # plt.axis('off') 18 | # plt.show() 19 | 20 | transform = transforms.Compose([ 21 | transforms.ToTensor() # 将图像转换为PyTorch张量 22 | ]) 23 | 24 | image_tensor = 
transform(gray_image) # 将灰度图像转换为张量 25 | return image_tensor 26 | 27 | 28 | def load_image_with_alpha_channel(image_path: str): 29 | image = Image.open(image_path).convert("RGBA") # 读取图像并转换为RGB模式 30 | 31 | # 如果图像有alpha通道,将RGBA图像转换为灰度图,并将透明部分填充为白色 32 | background = Image.new("RGBA", image.size, (255, 255, 255)) 33 | alpha_composite = Image.alpha_composite(background, image) 34 | # gray_image = alpha_composite.convert("L") 35 | 36 | # return gray_image 37 | return alpha_composite 38 | 39 | 40 | def convert_image_to_tensor(image_pcb): 41 | image = image_pcb.convert("RGBA") # 读取图像并转换为RGB模式 42 | 43 | # 如果图像有alpha通道,将RGBA图像转换为灰度图,并将透明部分填充为白色 44 | background = Image.new("RGBA", image.size, (255, 255, 255)) 45 | alpha_composite = Image.alpha_composite(background, image) 46 | gray_image = alpha_composite.convert("L") 47 | 48 | # 显示灰度图像 49 | # plt.figure(figsize=(8, 8)) 50 | # plt.imshow(gray_image, cmap='gray') 51 | # plt.axis('off') 52 | # plt.show() 53 | 54 | transform = transforms.Compose([ 55 | transforms.ToTensor() # 将图像转换为PyTorch张量 56 | ]) 57 | 58 | image_tensor = transform(gray_image) # 将灰度图像转换为张量 59 | return image_tensor 60 | 61 | 62 | def display_image(image): 63 | plt.imshow(image) 64 | plt.axis('off') 65 | plt.show() 66 | 67 | 68 | def convert_to_rgb(img): 69 | return img.convert('RGB') 70 | -------------------------------------------------------------------------------- /MetaLearning.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from torch.utils.data import DataLoader 4 | 5 | from AutoEncoder import * 6 | from DataAugment import limit_size, do_nothing 7 | from Datasets import * 8 | from VAE_GAN_train import load_image_datasets 9 | 10 | 11 | def augment_tensor_dataset(tensor_dataset): 12 | augmented_tensors = [] 13 | for tensor in tensor_dataset: 14 | # augmented_images = augment_image_tensor(tensor) 15 | augmented_images = do_nothing(tensor) 16 | for i in range(len(augmented_images)): 17 | augmented_images[i] = limit_size(augmented_images[i]) 18 | augmented_tensors.extend(augmented_images) 19 | return torch.stack(augmented_tensors) # 将列表转换回一个新的张量 20 | 21 | 22 | class MAML: 23 | def __init__(self, inner_lr, beta, d): 24 | self.device = d 25 | encoder = VAEEncoder(latent_dim) 26 | decoder = VAEDecoder(latent_dim) 27 | self.model = VAEModel(encoder, decoder).to(self.device) 28 | self.inner_lr = inner_lr 29 | self.beta = beta 30 | self.grad_clip_norm = 1.0 # 添加梯度裁剪的范数值 31 | 32 | def inner_update(self, x): 33 | x_sample, z_mu, z_var = self.model(x) 34 | inner_loss = compute_task_loss(x_sample, x, z_mu, z_var) 35 | self.model.zero_grad() 36 | inner_loss.backward() 37 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_clip_norm) # 梯度裁剪 38 | inner_optimizer = optim.Adam(self.model.parameters(), lr=self.inner_lr) 39 | inner_optimizer.step() 40 | return inner_loss 41 | 42 | def meta_update(self, num_tasks, general_loader, specific_loader, num_samples_per_task): 43 | # 将模型参数转换为浮点类型 44 | for name, param in self.model.named_parameters(): 45 | if param.data.dtype != torch.float32: 46 | param.data = param.data.float() 47 | 48 | param_dict = deepcopy(self.model.state_dict()) 49 | param_dict = {name: torch.zeros_like(param_dict[name], dtype=torch.float32, requires_grad=True) for name in 50 | param_dict} 51 | 52 | for _ in range(num_tasks): 53 | x_task = sample_task_data(general_loader, specific_loader, num_samples_per_task) 54 | x_task_viewed = x_task.view(-1, input_channel, 512, 512) 55 | 56 | 
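            # Inner-loop step: adapt the shared VAE on the support batch with a
            # single optimizer step (self.inner_update below), then evaluate a
            # second loss with the adapted weights on freshly sampled query data.
            # The per-parameter mean gradients of that loss are accumulated in
            # param_dict and applied after the task loop as one meta-update
            # scaled by self.beta / num_tasks.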
self.inner_update(x_task_viewed) 57 | updated_param = deepcopy(self.model.state_dict()) 58 | 59 | x_query = sample_task_data(general_loader, specific_loader, num_samples_per_task) 60 | x_query_viewed = x_query.view(-1, input_channel, 512, 512) 61 | 62 | self.model.load_state_dict(updated_param) 63 | x_sample, z_mu, z_var = self.model(x_query_viewed) 64 | task_loss = compute_task_loss(x_sample, x_task_viewed, z_mu, z_var) 65 | print('\r\rtast_loss:', task_loss.item()) 66 | self.model.zero_grad() 67 | task_loss.backward() 68 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_clip_norm) # 梯度裁剪 69 | 70 | meta_grad = {} 71 | for name, params in zip(self.model.state_dict(), self.model.parameters()): 72 | if params.grad is not None: 73 | meta_grad[name] = torch.mean(params.grad.data) # 对梯度进行平均 74 | 75 | for name in param_dict: 76 | if name in meta_grad: 77 | param_dict[name] = param_dict[name] + meta_grad[name] * torch.ones_like( 78 | param_dict[name]) # 标量梯度乘以全1张量并累加 79 | 80 | net_params = self.model.state_dict() 81 | net_params_new = {name: net_params[name] + self.beta * param_dict[name] / num_tasks for name in net_params} 82 | self.model.load_state_dict(net_params_new) 83 | 84 | 85 | def train_maml_vae(mlma_model, general_loader, specific_loader, num_tasks, num_inner_steps, num_samples_per_task, 86 | meta_iteration): 87 | global device 88 | torch.cuda.empty_cache() 89 | print(f"Meta Iteration: {meta_iteration}") 90 | 91 | for inner_step in range(num_inner_steps): 92 | mlma_model.meta_update(num_tasks, general_loader, specific_loader, num_samples_per_task) 93 | print(f"\rInner Step: {inner_step + 1}/{num_inner_steps}") 94 | 95 | 96 | def compute_task_loss(x_recon, x, z_mu, z_var): 97 | # 重建损失 98 | recon_loss = F.mse_loss(x_recon, x, reduction='sum') 99 | 100 | # KL散度损失 101 | kl_loss = -0.5 * torch.sum(1 + z_var - z_mu.pow(2) - z_var.exp()) 102 | 103 | # 总损失 104 | task_loss = recon_loss + kl_loss 105 | return task_loss 106 | 107 | 108 | def sample_task_data(general_loader, specific_loader, num_samples): 109 | """ 110 | Sample task-specific data from the general and specific data loaders. 
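    Draws `num_samples` indices (with replacement) from a single batch of each
    loader, concatenates the selected general and specific samples along the
    batch dimension, and returns the result on the module-level `device`.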
111 | """ 112 | general_data = next(iter(general_loader)) 113 | specific_data = next(iter(specific_loader)) 114 | 115 | general_indices = torch.randint(0, len(general_data), (num_samples,)) 116 | specific_indices = torch.randint(0, len(specific_data), (num_samples,)) 117 | 118 | task_data = torch.cat((general_data[general_indices], specific_data[specific_indices]), dim=0) 119 | global device 120 | return task_data.to(device) 121 | 122 | 123 | def seg_tensor_dataset(tensor_dataset): 124 | part1, part2 = [], [] 125 | for index, tensor in enumerate(tensor_dataset): 126 | image = transforms.ToPILImage()(tensor) 127 | if index < 5: 128 | part1.append(transforms.ToTensor()(image)) 129 | else: 130 | part2.append(transforms.ToTensor()(image)) 131 | return torch.stack(part1), torch.stack(part2) 132 | 133 | 134 | def merge_datasets(tensor_dataset_list: list): 135 | merged_dataset = [] 136 | for dataset in tensor_dataset_list: 137 | merged_dataset.extend(dataset) 138 | return torch.stack(merged_dataset) 139 | 140 | 141 | def load_datasets(zoom_factor=1.2): 142 | image_directory = "./cut_imgs/" 143 | 144 | # 读取图片并创建数据集 145 | loaded_datasets = [load_image_datasets(image_directory)] 146 | 147 | # 合并训练数据集 148 | combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets) 149 | general_data = augment_tensor_dataset(combined_dataset) 150 | 151 | # 加载测试数据集 152 | test_dataset_path = './datasets/test_dataset.pt' 153 | 154 | if os.path.isfile(test_dataset_path): 155 | specific_data = torch.load(test_dataset_path) 156 | print(f"Test dataset has been loaded and merged from {test_dataset_path}.") 157 | augmented_specific_data = [] 158 | for image_tensor in specific_data: 159 | augmented_tensor = do_nothing(image_tensor, zoom_factor) 160 | augmented_specific_data.append(augmented_tensor) 161 | specific_data = augmented_specific_data 162 | 163 | else: 164 | print("Test dataset files not found.") 165 | specific_data = None 166 | return general_data, specific_data 167 | 168 | 169 | if __name__ == '__main__': 170 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 171 | # general_data: 通用训练材料 172 | # specific_data: 指定任务的训练素材 173 | # 在主函数中调用 load_datasets 函数 174 | general_data, specific_data = load_datasets() 175 | 176 | general_dataset = TensorDataset(general_data) 177 | specific_dataset = TensorDataset(specific_data) 178 | 179 | general_loader = DataLoader(general_dataset, batch_size=5, shuffle=True) 180 | specific_loader = DataLoader(specific_dataset, batch_size=5, shuffle=True) 181 | 182 | # 定义MAML模型 183 | maml_model = MAML(1e-5, 1e-8, device) 184 | 185 | model_path = './saved_model/VAE_cold_start.pth' 186 | model_dir = os.path.dirname(model_path) 187 | print('Try to load model from', model_path) 188 | # 检查模型文件夹路径是否存在 189 | if not os.path.exists(model_dir): 190 | # 不存在就创建新的目录 191 | os.makedirs(model_dir) 192 | print(f"Created directory '{model_dir}' for saving models.") 193 | if os.path.isfile(model_path): 194 | try: 195 | maml_model.model.load_state_dict(torch.load(model_path, map_location=device)) 196 | print("Model loaded successfully from '{}'".format(model_path)) 197 | except Exception as e: 198 | print("Failed to load model. Starting from scratch. Error: ", e) 199 | else: 200 | print("No saved model found at '{}'. 
Starting from scratch.".format(model_path)) 201 | 202 | # 定义训练循环需要的变量 203 | num_meta_iterations = 100 # 元迭代次数 204 | num_tasks = 5 # 任务数量 205 | num_inner_steps = 5 # 每个任务的内部更新步数 206 | num_samples_per_task = 10 # 每个任务采样的样本数 207 | 208 | # 训练循环 209 | for meta_iteration in range(num_meta_iterations): 210 | print('-' * 10, 'Meta Iteration', meta_iteration, '-' * 10) 211 | 212 | # 调用train_maml_vae函数进行训练 213 | train_maml_vae(maml_model, general_loader, specific_loader, num_tasks, num_inner_steps, num_samples_per_task, 214 | meta_iteration) 215 | 216 | # 保存当前的模型参数 217 | vae_model_state_dict = maml_model.model.state_dict() 218 | torch.save(vae_model_state_dict, './saved_model/VAE_cold_start.pth') 219 | -------------------------------------------------------------------------------- /ModelLoader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import torch 5 | from torch.utils.data import DataLoader, random_split 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | class ModelLoader: 10 | def __init__(self, train_dataset, test_dataset, batch_size, model_path: str, if_early_stop=False, debug_mode=False): 11 | """ 12 | 初始化模型加载器。 13 | 14 | Args: 15 | train_dataset: 用于训练的数据集。 16 | test_dataset: 用于测试的数据集。 17 | batch_size: 每批数据的大小。 18 | model_path: 保存或加载模型权重的路径。 19 | if_early_stop: 是否早停。 20 | debug_mode: 是否为调试模式,调试模式下可能会启用额外的日志或检查点。 21 | """ 22 | self.predict_mode = False 23 | if train_dataset is None and test_dataset is None: 24 | print('Model will run in predict mode.') 25 | self.predict_mode = True 26 | elif train_dataset is not None and test_dataset is not None: 27 | self.train_dataset = train_dataset 28 | self.test_dataset = test_dataset 29 | self.batch_size = batch_size 30 | self.train_iterator = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True) 31 | self.test_iterator = DataLoader(self.test_dataset, batch_size=self.batch_size, drop_last=True) 32 | elif train_dataset is not None and test_dataset is None: 33 | print('Generate test dataset randomly.') 34 | # 如果没有提供测试集,则从训练集中随机选择一部分作为测试集 35 | test_ratio = 0.01 36 | train_size = int(len(train_dataset) * (1 - test_ratio)) 37 | test_size = len(train_dataset) - train_size 38 | self.batch_size = batch_size 39 | self.train_dataset, self.test_dataset = random_split(train_dataset, [train_size, test_size]) 40 | self.train_iterator = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True) 41 | self.test_iterator = DataLoader(self.test_dataset, batch_size=self.batch_size, drop_last=True) 42 | 43 | self.if_early_stop = if_early_stop 44 | self.debug_mode = debug_mode 45 | if debug_mode: 46 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 47 | torch.autograd.set_detect_anomaly(True) 48 | 49 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 50 | self.model_path = model_path 51 | self.model = None 52 | self.lr = None 53 | self.optimizer = None 54 | self.scheduler = None 55 | 56 | self.best_loss = float('inf') 57 | self.train_losses = [] 58 | self.test_losses = [] 59 | self.current_epoch = 0 60 | 61 | def _train_epoch(self): 62 | """ 63 | 训练模型一个epoch,子类需要根据具体的模型实现该方法。 64 | """ 65 | raise NotImplementedError('Subclasses should implement this method.') 66 | 67 | def _test_epoch(self): 68 | """ 69 | 测试模型一个epoch,子类需要根据具体的模型实现该方法。 70 | """ 71 | raise NotImplementedError('Subclasses should implement this method.') 72 | 73 | def train(self, epochs=50, test_interval=1, figure_interval=10, backup_interval=10): 74 | """ 
75 | 训练模型,周期性地在测试集上评估性能。 76 | 77 | Args: 78 | epochs: 训练的总轮次。 79 | test_interval: 测试间隔。 80 | """ 81 | if self.predict_mode: 82 | print('No data given, model is running in predict mode.') 83 | return 84 | print('Start training...') 85 | if self.if_early_stop: 86 | # 早停策略防止过拟合 87 | best_test_loss = float('inf') 88 | patience_counter = 0 89 | 90 | figure_dir = './figure' 91 | os.makedirs(figure_dir, exist_ok=True) 92 | 93 | for epoch in range(1, epochs + 1): 94 | self.current_epoch = epoch 95 | train_loss = self._train_epoch() 96 | test_loss = self._test_epoch() 97 | self.train_losses.append(train_loss) 98 | self.test_losses.append(test_loss) 99 | print(f'Epoch {epoch}/{epochs} - Train Loss: {train_loss:.5f}, Test Loss: {test_loss:.5f}') 100 | 101 | if self.scheduler: 102 | self.scheduler.step() 103 | 104 | if epoch % backup_interval == 0: 105 | self.save_model(test_loss, create_backup=True) 106 | else: 107 | self.save_model(test_loss, create_backup=False) 108 | 109 | if epoch % figure_interval == 0: 110 | # 绘制损失曲线 111 | plt.figure(figsize=(10, 6)) 112 | plt.plot(range(1, epoch + 1), self.train_losses, label='Train Loss') 113 | plt.plot(range(1, epoch + 1), self.test_losses, label='Test Loss') 114 | plt.xlabel('Epoch') 115 | plt.ylabel('Loss') 116 | plt.title(f'Training and Test Loss - Epoch {epoch}') 117 | plt.legend() 118 | plt.grid(True) 119 | timestamp = time.strftime("%Y%m%d-%H%M%S") 120 | plt.savefig( 121 | f'{figure_dir}/{os.path.splitext(os.path.basename(self.model_path))[0]}-{timestamp}-epoch{epoch}.png') 122 | plt.close() 123 | 124 | if epoch % test_interval == 0: 125 | self.test() 126 | 127 | if self.if_early_stop: 128 | # 计算早停累计 129 | if best_test_loss > test_loss: 130 | best_test_loss = test_loss 131 | patience_counter = 1 132 | else: 133 | patience_counter += 1 134 | if patience_counter > max(epochs / 5, 10): 135 | # 早停 136 | print('Training interrupted to avoid overfitting.') 137 | break 138 | 139 | def test(self): 140 | raise NotImplementedError('Subclasses should implement this method.') 141 | 142 | def _save_model(self, save_path: str, is_best: bool = False, is_backup: bool = False): 143 | """ 144 | 保存模型的底层实现。 145 | 146 | Args: 147 | save_path: 保存模型的完整路径 148 | is_best: 是否是最佳模型 149 | is_backup: 是否是备份模型 150 | """ 151 | # 确保保存目录存在 152 | save_dir = os.path.dirname(save_path) 153 | if save_dir and not os.path.exists(save_dir): 154 | os.makedirs(save_dir, exist_ok=True) 155 | print(f'Created directory: {save_dir}') 156 | 157 | # 保存模型 158 | torch.save(self.model.state_dict(), save_path) 159 | 160 | # 根据保存类型打印相应信息 161 | if is_best: 162 | print(f'Best model saved to {save_path}') 163 | elif is_backup: 164 | print(f'Backup model saved to {save_path}') 165 | else: 166 | print(f'Current model saved to {save_path}') 167 | 168 | def save_model(self, loss=float('inf'), create_backup=False): 169 | """ 170 | 保存模型的权重,包括当前模型、最佳模型和可选的备份模型。 171 | 172 | Args: 173 | loss: 当前损失值 174 | create_backup: 是否创建备份模型 175 | """ 176 | model_name, model_extension = os.path.splitext(self.model_path) 177 | 178 | # 保存当前模型 179 | self._save_model(self.model_path) 180 | 181 | # 如果是最佳模型,额外保存一份 182 | if loss < self.best_loss: 183 | self.best_loss = loss 184 | best_model_path = f"{model_name}_best{model_extension}" 185 | self._save_model(best_model_path, is_best=True) 186 | 187 | # 根据参数决定是否创建备份 188 | if create_backup: 189 | backup_model_path = f"{model_name}_backup{model_extension}" 190 | self._save_model(backup_model_path, is_backup=True) 191 | 192 | def load_model(self): 193 | """ 194 | 加载模型的权重。 195 | """ 196 | # 
load model weight 197 | model_dir = os.path.dirname(self.model_path) 198 | print('Try to load model from', self.model_path) 199 | # 检查模型文件夹路径是否存在 200 | if not os.path.exists(model_dir): 201 | # 不存在就创建新的目录 202 | os.makedirs(model_dir, exist_ok=True) 203 | print(f"Created directory '{model_dir}' for saving models.") 204 | if os.path.isfile(self.model_path): 205 | try: 206 | self.model.load_state_dict(torch.load(self.model_path, map_location=self.device)) 207 | print("Model loaded successfully from '{}'".format(self.model_path)) 208 | except Exception as e: 209 | print("Failed to load model. Starting from scratch. Error: ", e) 210 | else: 211 | print("No saved model found at '{}'. Starting from scratch.".format(self.model_path)) 212 | 213 | def run(self, train_epochs=50): 214 | """ 215 | 执行训练和测试周期。 216 | """ 217 | try: 218 | self.train(train_epochs) 219 | except KeyboardInterrupt: 220 | print('Training interrupted by the user.') 221 | finally: 222 | self.save_model() 223 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VAE-CycleGAN Unsupervised Defect Detection 2 | 3 | This project implements an unsupervised defect detection algorithm based on VAE-CycleGAN for image reconstruction. The algorithm combines the power of Variational Autoencoder (VAE) and CycleGAN to detect defects/anomalies in images without the need for labeled data. 4 | 5 | # VAE-CycleGAN 无监督缺陷检测 6 | 7 | 本项目实现了一种基于 VAE-CycleGAN 的图像重建无监督缺陷检测算法。该算法结合了变分自编码器 (VAE) 和 CycleGAN 的优势,无需标注数据即可检测图像中的缺陷/异常。 8 | 9 | ## Features 10 | 11 | - Unsupervised learning: No need for labeled anomaly data 12 | - VAE-based image reconstruction 13 | - CycleGAN architecture for stable training and local discrimination 14 | - Segment Anything Model (SAM) for automatic image segmentation 15 | - Data augmentation techniques for improved performance 16 | - Use MAML Meta learning to train a general model for avoiding overfitting. 17 | 18 | ## 特点 19 | 20 | - 无监督学习:无需标注异常数据。 21 | - 基于 VAE 的图像重建。 22 | - CycleGAN 架构,实现稳定训练和局部判别。 23 | - Segment Anything Model (SAM) 自动图像分割。 24 | - 数据增强技术,提升存在缺陷时重建的性能。 25 | - 使用元学习得到更好的通用模型,避免过拟合。 26 | 27 | ## Architecture 28 | 29 | The VAE-CycleGAN model consists of two main components: 30 | 31 | 1. **VAE Generator**: A Variational Autoencoder that learns to reconstruct normal images. It consists of an encoder and a decoder network. 32 | 33 | 2. **PatchGAN Discriminator**: A discriminator network from CycleGAN that provides local feedback on the reconstructed images. 34 | 35 | The model is trained to minimize the reconstruction loss between the input images and their reconstructions, as well as the adversarial loss from the discriminator. In our implementation of CycleGAN, we modified the cyclic reconstruction process. Instead of the original cyclic reconstruction, we input the original image A to obtain the reconstructed image B, then input the reconstructed image B to obtain a new reconstructed image C. This modification greatly enhances the model's ability to handle images with anomalies. 36 | 37 | ## 架构 38 | 39 | VAE-CycleGAN 模型由两个主要组件构成: 40 | 41 | 1. **VAE 生成器**:一个变分自编码器,学习重建正常图像。它由编码器和解码器网络组成。 42 | 43 | 2. **PatchGAN 判别器**:来自 CycleGAN 的判别器网络,对重建图像提供局部反馈。 44 | 45 | 该模型通过最小化输入图像与其重建图像之间的重建损失,以及来自判别器的对抗损失来进行训练。在我们的 CycleGAN 实现中,我们修改了循环重建过程。我们不采用原始的循环重建,而是将原始图像 A 输入以获得重建图像 B,然后将重建图像 B 输入以获得新的重建图像 C。这种修改大大增强了模型处理含有异常的图像的能力。 46 | 47 | ## Usage 48 | 49 | 1. Prepare your dataset of normal images. 
50 | 2. Use the Segment Anything Model (SAM) in `CutTarget.py` to automatically segment the objects of interest in the images. 51 | 3. Train the VAE-CycleGAN model on the segmented images using the `VAE_GAN_train.py` script. Prepare the training images (normal samples) and test images (samples not seen during training or defective samples) before running the script. 52 | 4. To perform anomaly detection, use the `main.py` script. Modify the input image path and the model paths: one discriminator model for determining the accuracy of SAM segmentation, and the VAE model for reconstruction. The defect results will be saved in the `./output` folder by default. 53 | 54 | ## 使用方法 55 | 56 | 1. 准备正常的PCB图像数据集。 57 | 2. 使用 `CutTarget.py` 文件来调用 Segment Anything Model (SAM) 自动分割图像中感兴趣的对象。 58 | 3. 使用 `VAE_GAN_train.py` 脚本在分割后的图像上训练 VAE-CycleGAN 模型。在运行脚本之前,准备好训练图片(正常样本)和测试图片(模型训练时没有见过的样本或有缺陷样本)。 59 | 4. 进行缺陷检测时,使用 `main.py` 脚本。修改输入图片路径和模型路径:一个辨别器模型用于判断 SAM 的切割内容是否准确,以及用于重建的 VAE 模型。缺陷结果默认保存在 `./output` 文件夹中。 60 | 61 | ## References 62 | 63 | - Variational Autoencoder (VAE) 64 | - CycleGAN 65 | - [Segment Anything Model (SAM)](https://github.com/facebookresearch/segment-anything) 66 | 67 | ## 参考 68 | 69 | - 变分自编码器 (VAE) 70 | - CycleGAN 71 | - [Segment Anything Model (SAM)](https://github.com/facebookresearch/segment-anything) 72 | 73 | ## Effect 74 | 75 | ## 效果演示 76 | 77 | ```Python 78 | The probability of being a real PCB is: 0.2938 79 | Cut PCB from image successfully, take 3.729628086090088 seconds. 80 | VAE input image loaded, take 0.01900482177734375 seconds. 81 | VAE has regenerated the input image, take 0.08401799201965332 seconds. 82 | Dealing with regenerated image... 83 | (1139, 1139, 4) 84 | Structural Similarity (SSIM) Index: 0.856191228429266 85 | Total time cost: 3.981684923171997 seconds. 
86 | 相似度: 0.856191228429266 87 | ``` 88 | 89 | ![Effect Demonstration](./resource/effect.png) 90 | -------------------------------------------------------------------------------- /TargetDiscriminator.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from ImageTools import * 6 | from GAN import Discriminator 7 | 8 | 9 | class TargetDiscriminator: 10 | def __init__(self, model_path, device='cuda'): 11 | self.device = device 12 | self.model = Discriminator().to(self.device) 13 | self.model.load_state_dict(torch.load(model_path, map_location=self.device)) 14 | self.model.eval() 15 | 16 | self.transform = transforms.Resize((512, 512)) 17 | 18 | def predict(self, image_pcb): 19 | # image = load_image_to_tensor(image_path) 20 | image = convert_image_to_tensor(image_pcb) 21 | image = self.transform(image).unsqueeze(0).to(self.device) 22 | 23 | with torch.no_grad(): 24 | prob = self.model(image).item() 25 | 26 | return prob 27 | 28 | def predict_batch(self, image_path_list): 29 | images = [] 30 | for image_path in image_path_list: 31 | image = load_image_to_tensor(image_path) 32 | image = self.transform(image) 33 | images.append(image) 34 | 35 | images = torch.stack(images).to(self.device) 36 | 37 | with torch.no_grad(): 38 | probs = self.model(images).cpu().numpy() 39 | 40 | return probs 41 | 42 | 43 | if __name__ == '__main__': 44 | discriminator = TargetDiscriminator('saved_model/Discriminator_trained.pth') 45 | 46 | # 读取图片 47 | image_directory = 'dis_test/' 48 | image_paths = [os.path.join(image_directory, filename) for filename in os.listdir(image_directory) if filename.endswith(('.jpg', '.jpeg', '.png'))] 49 | 50 | # 批量图片预测 51 | probs = discriminator.predict_batch(image_paths) 52 | for path, prob in zip(image_paths, probs): 53 | # 获取图片名称 54 | image_name = os.path.basename(path) 55 | print(f'The probability of {image_name} being a real target is: {prob:.4f}') 56 | -------------------------------------------------------------------------------- /VAE_GAN_train.py: -------------------------------------------------------------------------------- 1 | from Datasets import TensorDataset 2 | from GenerateTestDataset import load_test_dataset 3 | from GAN import * 4 | from DataAugment import * 5 | from ImageTools import * 6 | 7 | 8 | def augment_tensor_dataset(tensor_dataset): 9 | augmented_tensors = [] 10 | for tensor in tensor_dataset: 11 | augmented_images = augment_image_tensor(tensor) 12 | # augmented_images = do_nothing(tensor) # 什么增强都不做的函数,只进行resize 13 | for i in range(len(augmented_images)): 14 | augmented_images[i] = limit_size(augmented_images[i]) 15 | augmented_tensors.extend(augmented_images) 16 | return torch.stack(augmented_tensors) # 将列表转换回一个新的张量 17 | 18 | 19 | def load_image_datasets(image_dir): 20 | image_files = os.listdir(image_dir) 21 | images = [] 22 | for image_file in image_files: 23 | image_path = os.path.join(image_dir, image_file) 24 | images.append(load_image_to_tensor(image_path)) 25 | dataset = TensorDataset(images) 26 | return dataset 27 | 28 | 29 | if __name__ == '__main__': 30 | loaded_datasets = [] 31 | """从'./train_images/目录中读取所有图片并加载成数据集用于训练'""" 32 | image_directory = "./train_images/" 33 | # 读取图片并创建数据集 34 | dataset = load_image_datasets(image_directory) 35 | loaded_datasets.append(dataset) 36 | 37 | # 合并数据集 38 | combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets) 39 | # 数据增强 40 | augmented_dataset = augment_tensor_dataset(combined_dataset) 41 | 42 | # 加载测试图片数据集 43 | # 
这里的图片测试集可以由GenerateTestDataset.py中的load_test_dataset()得到 44 | test_dataset_path = './datasets/test_dataset.pt' 45 | if os.path.isfile(test_dataset_path): 46 | test_dataset = torch.load(test_dataset_path) 47 | print(f"Test dataset has been loaded from {test_dataset_path}.") 48 | else: 49 | print(f"Error: Test dataset was not found at {test_dataset_path}.") 50 | # 没有的时候就调用函数加载测试图片数据集 51 | test_dataset = load_test_dataset() 52 | 53 | train_epochs = 60 54 | 55 | model = VAEGANModelLoader(augmented_dataset, test_dataset, 10, './saved_model/VAE.pth', 56 | './saved_model/Discriminator.pth') 57 | model.train(train_epochs, 5) 58 | model.test() 59 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from AutoEncoder import * 4 | from CutTarget import cut_target_from_image, apply_mask 5 | from ImagePreProcessing import * 6 | from ImageTools import * 7 | 8 | 9 | # 计算两个图像之间的结构相似性指数(SSIM) 10 | def calculate_ssim(img1, img2): 11 | C1 = (0.01 * 255) ** 2 12 | C2 = (0.03 * 255) ** 2 13 | img1 = img1.astype(np.float64) 14 | img2 = img2.astype(np.float64) 15 | kernel = cv2.getGaussianKernel(11, 1.5) 16 | window = np.outer(kernel, kernel.transpose()) 17 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 18 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 19 | mu1_sq = mu1 ** 2 20 | mu2_sq = mu2 ** 2 21 | mu1_mu2 = mu1 * mu2 22 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 23 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 24 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 25 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 26 | return ssim_map.mean() 27 | 28 | 29 | # 对图像进行预处理 30 | def preprocess_image(img): 31 | img = cv2.medianBlur(img, 3) # 中值滤波去噪 32 | return img 33 | 34 | 35 | # 对差异图像进行后处理 36 | def postprocess_diff(diff_img): 37 | # 高阈值二值化 38 | _, high_thresh_bin = cv2.threshold(diff_img, 200, 255, cv2.THRESH_BINARY) # 调整高阈值 39 | 40 | # 低阈值二值化 41 | _, diff_bin = cv2.threshold(diff_img, 45, 255, cv2.THRESH_BINARY) # 调整低阈值 42 | 43 | # 形态学操作 44 | diff_bin = cv2.dilate(diff_bin, np.ones((3, 3)), iterations=1) # 膨胀操作 45 | diff_bin = cv2.erode(diff_bin, np.ones((13, 13)), iterations=1) # 腐蚀操作 46 | diff_bin = cv2.dilate(diff_bin, np.ones((11, 11)), iterations=1) # 膨胀操作 47 | 48 | # 连通域分析 49 | num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(diff_bin, connectivity=8) 50 | 51 | # 创建一个新的掩码用于存储处理后的结果 52 | processed_diff_bin = np.zeros_like(diff_bin) 53 | 54 | # 遍历每个连通域 55 | for i in range(1, num_labels): 56 | # 获取当前连通域的掩码 57 | component_mask = (labels == i).astype(np.uint8) 58 | 59 | # 检查当前连通域在高阈值二值化图像中是否存在非零像素 60 | if cv2.countNonZero(high_thresh_bin & component_mask) > 0: 61 | # 如果存在非零像素,则保留该连通域 62 | processed_diff_bin |= component_mask 63 | return processed_diff_bin 64 | 65 | 66 | # 将缺陷图映射到原图中去 67 | def resize_defection_map_to_origin_size(defection_map, mask_image, coordinates, info, original_size): 68 | origin_cropped_defection_map = cv2.resize(defection_map, mask_image.size) 69 | x_min, x_max, y_min, y_max = coordinates 70 | width, height = original_size 71 | 72 | # 创建一个和原始大图相同大小的空白图像 73 | origin_defection_map = np.ones((height, width), dtype=np.uint8) * 0 # 填充为黑色 74 | 75 | # 将小图像放置到原始大图中指定位置 76 | origin_defection_map[y_min:y_max + 1, x_min:x_max + 1] = origin_cropped_defection_map 77 | return 
origin_defection_map 78 | 79 | 80 | def defection_detection(file_path, VAE_model=None, reason_mode='cuda'): 81 | image = Image.open(file_path) 82 | print('prepare to cut PCB from the image...') 83 | t0 = time.time() 84 | mask_image, coordinates, mask, info = cut_target_from_image(image, reason_mode=reason_mode) 85 | x_min, x_max, y_min, y_max = coordinates 86 | t1 = time.time() 87 | print(f'Cut PCB from image successfully, take {t1 - t0} seconds.') 88 | device = reason_mode 89 | VAE_model.to(device) 90 | image_tensor = convert_image_to_tensor(mask_image) 91 | 92 | size_transform = transforms.Resize((512, 512)) # 将图像调整为512x512 93 | # 应用预处理转换链 94 | input_tensor = size_transform(image_tensor) 95 | 96 | # 添加批次维度并将图像输入模型 97 | input_tensor = input_tensor.to(device).unsqueeze(0) # 添加批次维度,即从C x H x W变为1 x C x H x W 98 | 99 | t2 = time.time() 100 | print(f'VAE input image loaded, take {t2 - t1} seconds.') 101 | with torch.no_grad(): 102 | # 编码图像,获取潜在空间的均值 103 | z_mu, _ = VAE_model.encoder(input_tensor) 104 | # 解码 105 | regenerated_image = VAE_model.decoder(z_mu) 106 | 107 | regenerated_img = regenerated_image.cpu().squeeze(0) # 从batch中移除,得到inputchannel x 512 x 512的图片 108 | regenerated_img = regenerated_img.permute(1, 2, 0) # 调整为512x512x input_channel 109 | regenerated_img = regenerated_img.squeeze() # 去除单一通道维度 110 | # 对重建图像进行归一化处理 111 | regenerated_img = regenerated_img.clamp(0, 1) # 将像素值限制在[0, 1]范围内 112 | t3 = time.time() 113 | print(f'VAE has regenerated the input image, take {t3 - t2} seconds.') 114 | 115 | print('Dealing with regenerated image...') 116 | # 将归一化后的图像乘以255,转换为[0, 255]范围内的整数值 117 | regenerated_img_np = regenerated_img.numpy() 118 | regenerated_img_np = (regenerated_img_np * 255).astype(np.uint8) 119 | regenerated_img_np_origin = cv2.resize(regenerated_img_np, (mask.shape[1], mask.shape[0])) 120 | rgb_image = cv2.cvtColor(regenerated_img_np_origin, cv2.COLOR_GRAY2RGB) 121 | gray_image, mask = apply_mask(rgb_image, mask) 122 | regenerated_img_np = cv2.cvtColor(gray_image, cv2.COLOR_RGB2GRAY) 123 | regenerated_img_np = cv2.resize(regenerated_img_np, (512, 512)) 124 | 125 | # 读取并预处理原始图像和重建图像 126 | ori_image = load_image_with_alpha_channel(file_path) 127 | original_image = ori_image.copy() 128 | original_size = original_image.size 129 | original_image = np.array(original_image) 130 | original_image_crop, mask = apply_mask(original_image[y_min:y_max + 1, x_min:x_max + 1], mask) 131 | print(original_image_crop.shape) 132 | original_image = cv2.cvtColor(original_image_crop, cv2.COLOR_RGB2GRAY) 133 | original_image = preprocess_image(original_image) 134 | reconstructed_img = regenerated_img_np 135 | reconstructed_img = preprocess_image(reconstructed_img) 136 | 137 | # 调整大小并计算SSIM 138 | original_image = cv2.resize(original_image, (512, 512)) 139 | reconstructed_img = cv2.resize(reconstructed_img, (512, 512)) 140 | ssim = calculate_ssim(original_image, reconstructed_img) 141 | print("Structural Similarity (SSIM) Index: ", ssim) 142 | 143 | # 计算差异图像并进行后处理 144 | diff_img = np.abs(original_image.astype(np.float32) - reconstructed_img.astype(np.float32)) 145 | diff_img = (diff_img * 255.0 / diff_img.max()).astype(np.uint8) 146 | diff_bin = postprocess_diff(diff_img) 147 | 148 | diff_bin = resize_defection_map_to_origin_size(diff_bin, mask_image, coordinates, info, original_size) 149 | print(f'Total time cost: {time.time() - t0} seconds.') 150 | 151 | # 显示结果 152 | plt.figure(figsize=(20, 5)) 153 | plt.subplot(1, 4, 1) 154 | plt.imshow(ori_image, cmap='gray') 155 | plt.title('Original Image') 156 | 
plt.axis('off') 157 | plt.subplot(1, 4, 2) 158 | plt.imshow(reconstructed_img, cmap='gray') 159 | plt.title('Reconstructed Image') 160 | plt.axis('off') 161 | plt.subplot(1, 4, 3) 162 | plt.imshow(diff_img, cmap='jet') 163 | plt.title('Difference Image') 164 | plt.axis('off') 165 | plt.subplot(1, 4, 4) 166 | plt.imshow(diff_bin, cmap='gray') 167 | plt.title('Defect Regions') 168 | plt.axis('off') 169 | plt.tight_layout() 170 | plt.show() 171 | 172 | return ssim, diff_bin 173 | 174 | 175 | if __name__ == "__main__": 176 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 177 | # 初始化模型 178 | encoder = VAEEncoder(latent_dim) 179 | # decoder 180 | decoder = VAEDecoder(latent_dim) 181 | # VAE 182 | VAE_model = VAEModel(encoder, decoder).to(device) 183 | model_path = './saved_model/VAE.pth' 184 | VAE_model.load_state_dict(torch.load(model_path, map_location=device)) 185 | print("Model loaded successfully from '{}'".format(model_path)) 186 | 187 | input_dir = r'input' 188 | output_dir = r'output' 189 | image_files = [f for f in os.listdir(input_dir) if 190 | f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG', 'bmp'))] 191 | files_num = len(image_files) 192 | 193 | # 创建一个空列表用于存储图片对应关系 194 | image_pairs = [] 195 | 196 | for i in range(files_num): 197 | filename = image_files[i] 198 | print("第{}张图:".format(i + 1), filename) 199 | 200 | # 判断文件名是否包含'bin',如果包含则跳过 201 | if 'bin' in filename or '_t.' in filename: 202 | print("跳过文件: {}".format(filename)) 203 | continue 204 | 205 | score, diff_bin = defection_detection(os.path.join(input_dir, filename), VAE_model=VAE_model, 206 | reason_mode=str(device)) 207 | 208 | # 打印返回的分数 209 | print("相似度:", score) 210 | 211 | # 对二值化图像进行反色处理 212 | diff_bin = cv2.bitwise_not(diff_bin) 213 | 214 | # 将二值化后的图像保存到输出路径 215 | filename_without_ext = os.path.splitext(filename)[0] 216 | # 构建新的文件名 217 | new_filename = f"{filename_without_ext}_bin.png" 218 | output_path = os.path.join(output_dir, new_filename) 219 | cv2.imwrite(output_path, diff_bin) 220 | 221 | # 将输入图片文件名和输出图片文件名加入到图片对应关系列表中 222 | image_pairs.append(f"{filename} {new_filename}") 223 | 224 | # 将图片对应关系写入txt文件 225 | txt_file = "./output/mark.txt" 226 | 227 | # 检查文件是否存在 228 | if os.path.exists(txt_file): 229 | # 如果文件存在,则删除文件 230 | os.remove(txt_file) 231 | 232 | # 将新的图片对应关系写入txt文件 233 | with open(txt_file, "w") as file: 234 | file.write("\n".join(image_pairs)) 235 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.14.7 2 | opencv-python==4.8.1.78 3 | torch==2.1.0+cu121 4 | torchvision==0.16.0+cu121 5 | 6 | numpy~=1.26.4 7 | pillow~=9.3.0 8 | matplotlib~=3.8.0 -------------------------------------------------------------------------------- /resource/effect.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DEVILENMO/Unsupervised-Defect-Detection-Project-Based-on-VAE-GAN-Architecture/05e4d41edd4ee49b3a7b3e5ab399031c505c65f6/resource/effect.png -------------------------------------------------------------------------------- /segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
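# Illustrative usage of the API re-exported below (sketch only; the checkpoint
# path is a placeholder for a downloaded SAM weight file):
#
#     from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
#     sam = sam_model_registry["vit_h"](checkpoint="./sam_vit_h.pth")
#     mask_generator = SamAutomaticMaskGenerator(sam)
#     masks = mask_generator.generate(image)  # image: HWC uint8 numpy array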
6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /segment_anything/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 10 | 11 | from typing import Any, Dict, List, Optional, Tuple 12 | 13 | from .modeling import Sam 14 | from .predictor import SamPredictor 15 | from .utils.amg import ( 16 | MaskData, 17 | area_from_rle, 18 | batch_iterator, 19 | batched_mask_to_box, 20 | box_xyxy_to_xywh, 21 | build_all_layer_point_grids, 22 | calculate_stability_score, 23 | coco_encode_rle, 24 | generate_crop_boxes, 25 | is_box_near_crop_edge, 26 | mask_to_rle_pytorch, 27 | remove_small_regions, 28 | rle_to_mask, 29 | uncrop_boxes_xyxy, 30 | uncrop_masks, 31 | uncrop_points, 32 | ) 33 | 34 | 35 | class SamAutomaticMaskGenerator: 36 | def __init__( 37 | self, 38 | model: Sam, 39 | points_per_side: Optional[int] = 32, 40 | points_per_batch: int = 64, 41 | pred_iou_thresh: float = 0.88, 42 | stability_score_thresh: float = 0.95, 43 | stability_score_offset: float = 1.0, 44 | box_nms_thresh: float = 0.7, 45 | crop_n_layers: int = 0, 46 | crop_nms_thresh: float = 0.7, 47 | crop_overlap_ratio: float = 512 / 1500, 48 | crop_n_points_downscale_factor: int = 1, 49 | point_grids: Optional[List[np.ndarray]] = None, 50 | min_mask_region_area: int = 0, 51 | output_mode: str = "binary_mask", 52 | ) -> None: 53 | """ 54 | Using a SAM model, generates masks for the entire image. 55 | Generates a grid of point prompts over the image, then filters 56 | low quality and duplicate masks. The default settings are chosen 57 | for SAM with a ViT-H backbone. 58 | 59 | Arguments: 60 | model (Sam): The SAM model to use for mask prediction. 61 | points_per_side (int or None): The number of points to be sampled 62 | along one side of the image. The total number of points is 63 | points_per_side**2. If None, 'point_grids' must provide explicit 64 | point sampling. 65 | points_per_batch (int): Sets the number of points run simultaneously 66 | by the model. Higher numbers may be faster but use more GPU memory. 67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 68 | model's predicted mask quality. 69 | stability_score_thresh (float): A filtering threshold in [0,1], using 70 | the stability of the mask under changes to the cutoff used to binarize 71 | the model's mask predictions. 72 | stability_score_offset (float): The amount to shift the cutoff when 73 | calculated the stability score. 74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 75 | suppression to filter duplicate masks. 76 | crop_n_layers (int): If >0, mask prediction will be run again on 77 | crops of the image. Sets the number of layers to run, where each 78 | layer has 2**i_layer number of image crops. 79 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal 80 | suppression to filter duplicate masks between different crops. 
81 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 82 | In the first crop layer, crops will overlap by this fraction of 83 | the image length. Later layers with more crops scale down this overlap. 84 | crop_n_points_downscale_factor (int): The number of points-per-side 85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 86 | point_grids (list(np.ndarray) or None): A list over explicit grids 87 | of points used for sampling, normalized to [0,1]. The nth grid in the 88 | list is used in the nth crop layer. Exclusive with points_per_side. 89 | min_mask_region_area (int): If >0, postprocessing will be applied 90 | to remove disconnected regions and holes in masks with area smaller 91 | than min_mask_region_area. Requires opencv. 92 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 94 | For large resolutions, 'binary_mask' may consume large amounts of 95 | memory. 96 | """ 97 | 98 | assert (points_per_side is None) != ( 99 | point_grids is None 100 | ), "Exactly one of points_per_side or point_grid must be provided." 101 | if points_per_side is not None: 102 | self.point_grids = build_all_layer_point_grids( 103 | points_per_side, 104 | crop_n_layers, 105 | crop_n_points_downscale_factor, 106 | ) 107 | elif point_grids is not None: 108 | self.point_grids = point_grids 109 | else: 110 | raise ValueError("Can't have both points_per_side and point_grid be None.") 111 | 112 | assert output_mode in [ 113 | "binary_mask", 114 | "uncompressed_rle", 115 | "coco_rle", 116 | ], f"Unknown output_mode {output_mode}." 117 | if output_mode == "coco_rle": 118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 119 | 120 | if min_mask_region_area > 0: 121 | import cv2 # type: ignore # noqa: F401 122 | 123 | self.predictor = SamPredictor(model) 124 | self.points_per_batch = points_per_batch 125 | self.pred_iou_thresh = pred_iou_thresh 126 | self.stability_score_thresh = stability_score_thresh 127 | self.stability_score_offset = stability_score_offset 128 | self.box_nms_thresh = box_nms_thresh 129 | self.crop_n_layers = crop_n_layers 130 | self.crop_nms_thresh = crop_nms_thresh 131 | self.crop_overlap_ratio = crop_overlap_ratio 132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 133 | self.min_mask_region_area = min_mask_region_area 134 | self.output_mode = output_mode 135 | 136 | @torch.no_grad() 137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: 138 | """ 139 | Generates masks for the given image. 140 | 141 | Arguments: 142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 143 | 144 | Returns: 145 | list(dict(str, any)): A list over records for masks. Each record is 146 | a dict containing the following keys: 147 | segmentation (dict(str, any) or np.ndarray): The mask. If 148 | output_mode='binary_mask', is an array of shape HW. Otherwise, 149 | is a dictionary containing the RLE. 150 | bbox (list(float)): The box around the mask, in XYWH format. 151 | area (int): The area in pixels of the mask. 152 | predicted_iou (float): The model's own prediction of the mask's 153 | quality. This is filtered by the pred_iou_thresh parameter. 154 | point_coords (list(list(float))): The point coordinates input 155 | to the model to generate this mask. 156 | stability_score (float): A measure of the mask's quality. This 157 | is filtered on using the stability_score_thresh parameter. 
158 | crop_box (list(float)): The crop of the image used to generate 159 | the mask, given in XYWH format. 160 | """ 161 | 162 | # Generate masks 163 | mask_data = self._generate_masks(image) 164 | 165 | # Filter small disconnected regions and holes in masks 166 | if self.min_mask_region_area > 0: 167 | mask_data = self.postprocess_small_regions( 168 | mask_data, 169 | self.min_mask_region_area, 170 | max(self.box_nms_thresh, self.crop_nms_thresh), 171 | ) 172 | 173 | # Encode masks 174 | if self.output_mode == "coco_rle": 175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] 176 | elif self.output_mode == "binary_mask": 177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 178 | else: 179 | mask_data["segmentations"] = mask_data["rles"] 180 | 181 | # Write mask records 182 | curr_anns = [] 183 | for idx in range(len(mask_data["segmentations"])): 184 | ann = { 185 | "segmentation": mask_data["segmentations"][idx], 186 | "area": area_from_rle(mask_data["rles"][idx]), 187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 188 | "predicted_iou": mask_data["iou_preds"][idx].item(), 189 | "point_coords": [mask_data["points"][idx].tolist()], 190 | "stability_score": mask_data["stability_score"][idx].item(), 191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 192 | } 193 | curr_anns.append(ann) 194 | 195 | return curr_anns 196 | 197 | def _generate_masks(self, image: np.ndarray) -> MaskData: 198 | orig_size = image.shape[:2] 199 | crop_boxes, layer_idxs = generate_crop_boxes( 200 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 201 | ) 202 | 203 | # Iterate over image crops 204 | data = MaskData() 205 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 206 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 207 | data.cat(crop_data) 208 | 209 | # Remove duplicate masks between crops 210 | if len(crop_boxes) > 1: 211 | # Prefer masks from smaller crops 212 | scores = 1 / box_area(data["crop_boxes"]) 213 | scores = scores.to(data["boxes"].device) 214 | keep_by_nms = batched_nms( 215 | data["boxes"].float(), 216 | scores, 217 | torch.zeros_like(data["boxes"][:, 0]), # categories 218 | iou_threshold=self.crop_nms_thresh, 219 | ) 220 | data.filter(keep_by_nms) 221 | 222 | data.to_numpy() 223 | return data 224 | 225 | def _process_crop( 226 | self, 227 | image: np.ndarray, 228 | crop_box: List[int], 229 | crop_layer_idx: int, 230 | orig_size: Tuple[int, ...], 231 | ) -> MaskData: 232 | # Crop the image and calculate embeddings 233 | x0, y0, x1, y1 = crop_box 234 | cropped_im = image[y0:y1, x0:x1, :] 235 | cropped_im_size = cropped_im.shape[:2] 236 | self.predictor.set_image(cropped_im) 237 | 238 | # Get points for this crop 239 | points_scale = np.array(cropped_im_size)[None, ::-1] 240 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 241 | 242 | # Generate masks for this crop in batches 243 | data = MaskData() 244 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 245 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) 246 | data.cat(batch_data) 247 | del batch_data 248 | self.predictor.reset_image() 249 | 250 | # Remove duplicates within this crop. 
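        # batched_nms is used as class-agnostic NMS here: every box is given the
        # same (zero) category id, so suppression depends only on box overlap and
        # the model's predicted-IoU scores, keeping the highest-scoring mask
        # among overlapping duplicates within this crop.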
251 | keep_by_nms = batched_nms( 252 | data["boxes"].float(), 253 | data["iou_preds"], 254 | torch.zeros_like(data["boxes"][:, 0]), # categories 255 | iou_threshold=self.box_nms_thresh, 256 | ) 257 | data.filter(keep_by_nms) 258 | 259 | # Return to the original image frame 260 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 261 | data["points"] = uncrop_points(data["points"], crop_box) 262 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 263 | 264 | return data 265 | 266 | def _process_batch( 267 | self, 268 | points: np.ndarray, 269 | im_size: Tuple[int, ...], 270 | crop_box: List[int], 271 | orig_size: Tuple[int, ...], 272 | ) -> MaskData: 273 | orig_h, orig_w = orig_size 274 | 275 | # Run model on this batch 276 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 277 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device) 278 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 279 | masks, iou_preds, _ = self.predictor.predict_torch( 280 | in_points[:, None, :], 281 | in_labels[:, None], 282 | multimask_output=True, 283 | return_logits=True, 284 | ) 285 | 286 | # Serialize predictions and store in MaskData 287 | data = MaskData( 288 | masks=masks.flatten(0, 1), 289 | iou_preds=iou_preds.flatten(0, 1), 290 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 291 | ) 292 | del masks 293 | 294 | # Filter by predicted IoU 295 | if self.pred_iou_thresh > 0.0: 296 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 297 | data.filter(keep_mask) 298 | 299 | # Calculate stability score 300 | data["stability_score"] = calculate_stability_score( 301 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset 302 | ) 303 | if self.stability_score_thresh > 0.0: 304 | keep_mask = data["stability_score"] >= self.stability_score_thresh 305 | data.filter(keep_mask) 306 | 307 | # Threshold masks and calculate boxes 308 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold 309 | data["boxes"] = batched_mask_to_box(data["masks"]) 310 | 311 | # Filter boxes that touch crop boundaries 312 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 313 | if not torch.all(keep_mask): 314 | data.filter(keep_mask) 315 | 316 | # Compress to RLE 317 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 318 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 319 | del data["masks"] 320 | 321 | return data 322 | 323 | @staticmethod 324 | def postprocess_small_regions( 325 | mask_data: MaskData, min_area: int, nms_thresh: float 326 | ) -> MaskData: 327 | """ 328 | Removes small disconnected regions and holes in masks, then reruns 329 | box NMS to remove any new duplicates. 330 | 331 | Edits mask_data in place. 332 | 333 | Requires open-cv as a dependency. 
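
        Returns:
            MaskData: the same mask_data object with small regions and holes
            removed, affected RLEs and boxes recalculated, and duplicates
            filtered by box NMS.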
334 | """ 335 | if len(mask_data["rles"]) == 0: 336 | return mask_data 337 | 338 | # Filter small disconnected regions and holes 339 | new_masks = [] 340 | scores = [] 341 | for rle in mask_data["rles"]: 342 | mask = rle_to_mask(rle) 343 | 344 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 345 | unchanged = not changed 346 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 347 | unchanged = unchanged and not changed 348 | 349 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 350 | # Give score=0 to changed masks and score=1 to unchanged masks 351 | # so NMS will prefer ones that didn't need postprocessing 352 | scores.append(float(unchanged)) 353 | 354 | # Recalculate boxes and remove any new duplicates 355 | masks = torch.cat(new_masks, dim=0) 356 | boxes = batched_mask_to_box(masks) 357 | keep_by_nms = batched_nms( 358 | boxes.float(), 359 | torch.as_tensor(scores), 360 | torch.zeros_like(boxes[:, 0]), # categories 361 | iou_threshold=nms_thresh, 362 | ) 363 | 364 | # Only recalculate RLEs for masks that have changed 365 | for i_mask in keep_by_nms: 366 | if scores[i_mask] == 0.0: 367 | mask_torch = masks[i_mask].unsqueeze(0) 368 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 369 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 370 | mask_data.filter(keep_by_nms) 371 | 372 | return mask_data 373 | -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
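#
# A quick usage sketch (the checkpoint path below is an assumption; substitute
# any locally downloaded SAM checkpoint that matches the chosen variant):
#
#     from segment_anything.build_sam import sam_model_registry
#
#     sam = sam_model_registry["vit_b"](checkpoint="checkpoints/sam_vit_b.pth")
#     sam.to("cuda")  # optional; _build_sam already returns the model in eval mode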
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d, MLPBlock 14 | 15 | 16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 17 | class ImageEncoderViT(nn.Module): 18 | def __init__( 19 | self, 20 | img_size: int = 1024, 21 | patch_size: int = 16, 22 | in_chans: int = 3, 23 | embed_dim: int = 768, 24 | depth: int = 12, 25 | num_heads: int = 12, 26 | mlp_ratio: float = 4.0, 27 | out_chans: int = 256, 28 | qkv_bias: bool = True, 29 | norm_layer: Type[nn.Module] = nn.LayerNorm, 30 | act_layer: Type[nn.Module] = nn.GELU, 31 | use_abs_pos: bool = True, 32 | use_rel_pos: bool = False, 33 | rel_pos_zero_init: bool = True, 34 | window_size: int = 0, 35 | global_attn_indexes: Tuple[int, ...] = (), 36 | ) -> None: 37 | """ 38 | Args: 39 | img_size (int): Input image size. 40 | patch_size (int): Patch size. 41 | in_chans (int): Number of input image channels. 42 | embed_dim (int): Patch embedding dimension. 43 | depth (int): Depth of ViT. 44 | num_heads (int): Number of attention heads in each ViT block. 45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 46 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 47 | norm_layer (nn.Module): Normalization layer. 48 | act_layer (nn.Module): Activation layer. 49 | use_abs_pos (bool): If True, use absolute positional embeddings. 50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 
51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 52 | window_size (int): Window size for window attention blocks. 53 | global_attn_indexes (list): Indexes for blocks using global attention. 54 | """ 55 | super().__init__() 56 | self.img_size = img_size 57 | 58 | self.patch_embed = PatchEmbed( 59 | kernel_size=(patch_size, patch_size), 60 | stride=(patch_size, patch_size), 61 | in_chans=in_chans, 62 | embed_dim=embed_dim, 63 | ) 64 | 65 | self.pos_embed: Optional[nn.Parameter] = None 66 | if use_abs_pos: 67 | # Initialize absolute positional embedding with pretrain image size. 68 | self.pos_embed = nn.Parameter( 69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 70 | ) 71 | 72 | self.blocks = nn.ModuleList() 73 | for i in range(depth): 74 | block = Block( 75 | dim=embed_dim, 76 | num_heads=num_heads, 77 | mlp_ratio=mlp_ratio, 78 | qkv_bias=qkv_bias, 79 | norm_layer=norm_layer, 80 | act_layer=act_layer, 81 | use_rel_pos=use_rel_pos, 82 | rel_pos_zero_init=rel_pos_zero_init, 83 | window_size=window_size if i not in global_attn_indexes else 0, 84 | input_size=(img_size // patch_size, img_size // patch_size), 85 | ) 86 | self.blocks.append(block) 87 | 88 | self.neck = nn.Sequential( 89 | nn.Conv2d( 90 | embed_dim, 91 | out_chans, 92 | kernel_size=1, 93 | bias=False, 94 | ), 95 | LayerNorm2d(out_chans), 96 | nn.Conv2d( 97 | out_chans, 98 | out_chans, 99 | kernel_size=3, 100 | padding=1, 101 | bias=False, 102 | ), 103 | LayerNorm2d(out_chans), 104 | ) 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = self.patch_embed(x) 108 | if self.pos_embed is not None: 109 | x = x + self.pos_embed 110 | 111 | for blk in self.blocks: 112 | x = blk(x) 113 | 114 | x = self.neck(x.permute(0, 3, 1, 2)) 115 | 116 | return x 117 | 118 | 119 | class Block(nn.Module): 120 | """Transformer blocks with support of window attention and residual propagation blocks""" 121 | 122 | def __init__( 123 | self, 124 | dim: int, 125 | num_heads: int, 126 | mlp_ratio: float = 4.0, 127 | qkv_bias: bool = True, 128 | norm_layer: Type[nn.Module] = nn.LayerNorm, 129 | act_layer: Type[nn.Module] = nn.GELU, 130 | use_rel_pos: bool = False, 131 | rel_pos_zero_init: bool = True, 132 | window_size: int = 0, 133 | input_size: Optional[Tuple[int, int]] = None, 134 | ) -> None: 135 | """ 136 | Args: 137 | dim (int): Number of input channels. 138 | num_heads (int): Number of attention heads in each ViT block. 139 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 140 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 141 | norm_layer (nn.Module): Normalization layer. 142 | act_layer (nn.Module): Activation layer. 143 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 144 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 145 | window_size (int): Window size for window attention blocks. If it equals 0, then 146 | use global attention. 147 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 148 | positional parameter size. 
149 | """ 150 | super().__init__() 151 | self.norm1 = norm_layer(dim) 152 | self.attn = Attention( 153 | dim, 154 | num_heads=num_heads, 155 | qkv_bias=qkv_bias, 156 | use_rel_pos=use_rel_pos, 157 | rel_pos_zero_init=rel_pos_zero_init, 158 | input_size=input_size if window_size == 0 else (window_size, window_size), 159 | ) 160 | 161 | self.norm2 = norm_layer(dim) 162 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 163 | 164 | self.window_size = window_size 165 | 166 | def forward(self, x: torch.Tensor) -> torch.Tensor: 167 | shortcut = x 168 | x = self.norm1(x) 169 | # Window partition 170 | if self.window_size > 0: 171 | H, W = x.shape[1], x.shape[2] 172 | x, pad_hw = window_partition(x, self.window_size) 173 | 174 | x = self.attn(x) 175 | # Reverse window partition 176 | if self.window_size > 0: 177 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 178 | 179 | x = shortcut + x 180 | x = x + self.mlp(self.norm2(x)) 181 | 182 | return x 183 | 184 | 185 | class Attention(nn.Module): 186 | """Multi-head Attention block with relative position embeddings.""" 187 | 188 | def __init__( 189 | self, 190 | dim: int, 191 | num_heads: int = 8, 192 | qkv_bias: bool = True, 193 | use_rel_pos: bool = False, 194 | rel_pos_zero_init: bool = True, 195 | input_size: Optional[Tuple[int, int]] = None, 196 | ) -> None: 197 | """ 198 | Args: 199 | dim (int): Number of input channels. 200 | num_heads (int): Number of attention heads. 201 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 202 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 203 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 204 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 205 | positional parameter size. 206 | """ 207 | super().__init__() 208 | self.num_heads = num_heads 209 | head_dim = dim // num_heads 210 | self.scale = head_dim**-0.5 211 | 212 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 213 | self.proj = nn.Linear(dim, dim) 214 | 215 | self.use_rel_pos = use_rel_pos 216 | if self.use_rel_pos: 217 | assert ( 218 | input_size is not None 219 | ), "Input size must be provided if using relative positional encoding." 220 | # initialize relative positional embeddings 221 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 222 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 223 | 224 | def forward(self, x: torch.Tensor) -> torch.Tensor: 225 | B, H, W, _ = x.shape 226 | # qkv with shape (3, B, nHead, H * W, C) 227 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 228 | # q, k, v with shape (B * nHead, H * W, C) 229 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 230 | 231 | attn = (q * self.scale) @ k.transpose(-2, -1) 232 | 233 | if self.use_rel_pos: 234 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 235 | 236 | attn = attn.softmax(dim=-1) 237 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 238 | x = self.proj(x) 239 | 240 | return x 241 | 242 | 243 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 244 | """ 245 | Partition into non-overlapping windows with padding if needed. 246 | Args: 247 | x (tensor): input tokens with [B, H, W, C]. 248 | window_size (int): window size. 
249 | 250 | Returns: 251 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 252 | (Hp, Wp): padded height and width before partition 253 | """ 254 | B, H, W, C = x.shape 255 | 256 | pad_h = (window_size - H % window_size) % window_size 257 | pad_w = (window_size - W % window_size) % window_size 258 | if pad_h > 0 or pad_w > 0: 259 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 260 | Hp, Wp = H + pad_h, W + pad_w 261 | 262 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 263 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 264 | return windows, (Hp, Wp) 265 | 266 | 267 | def window_unpartition( 268 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 269 | ) -> torch.Tensor: 270 | """ 271 | Window unpartition into original sequences and removing padding. 272 | Args: 273 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 274 | window_size (int): window size. 275 | pad_hw (Tuple): padded height and width (Hp, Wp). 276 | hw (Tuple): original height and width (H, W) before padding. 277 | 278 | Returns: 279 | x: unpartitioned sequences with [B, H, W, C]. 280 | """ 281 | Hp, Wp = pad_hw 282 | H, W = hw 283 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 284 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 285 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 286 | 287 | if Hp > H or Wp > W: 288 | x = x[:, :H, :W, :].contiguous() 289 | return x 290 | 291 | 292 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 293 | """ 294 | Get relative positional embeddings according to the relative positions of 295 | query and key sizes. 296 | Args: 297 | q_size (int): size of query q. 298 | k_size (int): size of key k. 299 | rel_pos (Tensor): relative position embeddings (L, C). 300 | 301 | Returns: 302 | Extracted positional embeddings according to relative positions. 303 | """ 304 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 305 | # Interpolate rel pos if needed. 306 | if rel_pos.shape[0] != max_rel_dist: 307 | # Interpolate rel pos. 308 | rel_pos_resized = F.interpolate( 309 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 310 | size=max_rel_dist, 311 | mode="linear", 312 | ) 313 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 314 | else: 315 | rel_pos_resized = rel_pos 316 | 317 | # Scale the coords with short length if shapes for q and k are different. 318 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 319 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 320 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 321 | 322 | return rel_pos_resized[relative_coords.long()] 323 | 324 | 325 | def add_decomposed_rel_pos( 326 | attn: torch.Tensor, 327 | q: torch.Tensor, 328 | rel_pos_h: torch.Tensor, 329 | rel_pos_w: torch.Tensor, 330 | q_size: Tuple[int, int], 331 | k_size: Tuple[int, int], 332 | ) -> torch.Tensor: 333 | """ 334 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 335 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 336 | Args: 337 | attn (Tensor): attention map. 338 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 
339 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 340 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 341 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 342 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 343 | 344 | Returns: 345 | attn (Tensor): attention map with added relative positional embeddings. 346 | """ 347 | q_h, q_w = q_size 348 | k_h, k_w = k_size 349 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 350 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 351 | 352 | B, _, dim = q.shape 353 | r_q = q.reshape(B, q_h, q_w, dim) 354 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 355 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 356 | 357 | attn = ( 358 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 359 | ).view(B, q_h * q_w, k_h * k_w) 360 | 361 | return attn 362 | 363 | 364 | class PatchEmbed(nn.Module): 365 | """ 366 | Image to Patch Embedding. 367 | """ 368 | 369 | def __init__( 370 | self, 371 | kernel_size: Tuple[int, int] = (16, 16), 372 | stride: Tuple[int, int] = (16, 16), 373 | padding: Tuple[int, int] = (0, 0), 374 | in_chans: int = 3, 375 | embed_dim: int = 768, 376 | ) -> None: 377 | """ 378 | Args: 379 | kernel_size (Tuple): kernel size of the projection layer. 380 | stride (Tuple): stride of the projection layer. 381 | padding (Tuple): padding size of the projection layer. 382 | in_chans (int): Number of input image channels. 383 | embed_dim (int): Patch embedding dimension. 384 | """ 385 | super().__init__() 386 | 387 | self.proj = nn.Conv2d( 388 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 389 | ) 390 | 391 | def forward(self, x: torch.Tensor) -> torch.Tensor: 392 | x = self.proj(x) 393 | # B C H W -> B H W C 394 | x = x.permute(0, 2, 3, 1) 395 | return x 396 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. 
See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 
34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
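        Points take priority, then boxes, then masks; when no prompt is
        supplied, a batch size of 1 is returned.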
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 
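
        A minimal single-image sketch (the names and shapes are illustrative
        assumptions; the image must already be transformed for the model, and
        the full key set is described under Arguments below):

            batched_input = [{
                "image": image_tensor,             # 3xHxW, already transformed
                "original_size": (orig_h, orig_w),
                "point_coords": point_coords,      # BxNx2, in the input frame
                "point_labels": point_labels,      # BxN
            }]
            outputs = sam(batched_input, multimask_output=True)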
63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 
149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from segment_anything.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. 
Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 
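
        A minimal point-prompt sketch (the image array, coordinates, and label
        are illustrative assumptions; output shapes are described under
        Returns below):

            predictor = SamPredictor(sam)
            predictor.set_image(image)  # HWC uint8 array
            masks, scores, low_res_logits = predictor.predict(
                point_coords=np.array([[500, 375]]),
                point_labels=np.array([1]),
                multimask_output=True,
            )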
123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks_np = masks[0].detach().cpu().numpy() 164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 166 | return masks_np, iou_predictions_np, low_res_masks_np 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 190 | model, in XYXY format. 191 | mask_input (np.ndarray): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 
196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded masks logits 202 | instead of a binary mask. 203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /segment_anything/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - 
box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 
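# Because the high-threshold mask is a subset of the low-threshold mask, the intersection is simply the area of the high-threshold mask and the union is the area of the low-threshold mask, so no elementwise AND/OR is needed below.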
165 | # Save memory by preventing unnecessary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYXY format, clipped to the image boundary 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate
transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 
335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information is returned. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks:
torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. 
Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | --------------------------------------------------------------------------------
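A minimal usage sketch (illustrative only, not part of the repository) showing how ResizeLongestSide, defined above, rescales an image size and its prompts for a 1024-pixel long side; SamPredictor.set_image and predict apply the same apply_image, apply_coords, and apply_boxes transforms internally. The image size and coordinates below are made-up values chosen so the arithmetic is easy to follow.

import numpy as np
from segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(target_length=1024)

# A 600x800 (H, W) image is scaled so its longest side becomes 1024:
# scale = 1024 / 800 = 1.28, so the resized image is (768, 1024).
print(ResizeLongestSide.get_preprocess_shape(600, 800, 1024))  # (768, 1024)

# Point prompts are rescaled by the same factors before being fed to the model.
coords = np.array([[400.0, 300.0]])  # (X, Y) in the original image
print(transform.apply_coords(coords, original_size=(600, 800)))  # [[512. 384.]]

# Boxes in XYXY format reuse apply_coords on their two corner points.
box = np.array([[100.0, 150.0, 700.0, 450.0]])
print(transform.apply_boxes(box, original_size=(600, 800)))  # [[128. 192. 896. 576.]]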