├── src
│   ├── utils
│   │   ├── __init__.py
│   │   └── utils.py
│   ├── scripts
│   │   ├── bcsunet_all.sh
│   │   ├── scsnet_all.sh
│   │   ├── reconnet_all.sh
│   │   └── train_all.sh
│   ├── model
│   │   ├── bcsunet.py
│   │   ├── utils.py
│   │   ├── spectral_norm.py
│   │   ├── upsamplenet.py
│   │   ├── layers.py
│   │   └── unet.py
│   ├── data
│   │   ├── base_dataset.py
│   │   ├── svhn.py
│   │   ├── emnist.py
│   │   └── stl10.py
│   ├── engine
│   │   ├── dispatcher.py
│   │   └── learner.py
│   ├── evaluate.py
│   ├── benchmark
│   │   ├── reconnet
│   │   │   ├── net.py
│   │   │   ├── train.py
│   │   │   ├── learner.py
│   │   │   ├── inference_simulated.py
│   │   │   └── inference_spi.py
│   │   └── scsnet
│   │       ├── net.py
│   │       ├── train.py
│   │       ├── inference_simulated.py
│   │       ├── learner.py
│   │       └── inference_spi.py
│   ├── inference_simulated.py
│   ├── train.py
│   └── inference_spi.py
├── input
│   └── phi_block_binary.npy
├── requirements.txt
├── config
│   ├── inference_config.yaml
│   ├── scsnet_config.yaml
│   ├── reconnet_config.yaml
│   ├── bcsunet_STL10.yaml
│   ├── bcsunet_SVHN.yaml
│   └── bcsunet_EMNIST.yaml
├── LICENSE
├── README.md
└── .gitignore
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 |
--------------------------------------------------------------------------------
/input/phi_block_binary.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stephenllh/bcs-unet/HEAD/input/phi_block_binary.npy
--------------------------------------------------------------------------------
/src/scripts/bcsunet_all.sh:
--------------------------------------------------------------------------------
1 | python train.py -d EMNIST -s 12.5   # -s (sampling ratio in percent) is required by train.py; 12.5 assumed
2 | python train.py -d SVHN -s 12.5
3 | python train.py -d STL10 -s 12.5
--------------------------------------------------------------------------------
/src/scripts/scsnet_all.sh:
--------------------------------------------------------------------------------
1 | python -m benchmark.scsnet.train -d EMNIST -s 12.5   # -s (sampling ratio in percent) is required; 12.5 assumed
2 | python -m benchmark.scsnet.train -d SVHN -s 12.5
3 | python -m benchmark.scsnet.train -d STL10 -s 12.5
--------------------------------------------------------------------------------
/src/scripts/reconnet_all.sh:
--------------------------------------------------------------------------------
1 | python -m benchmark.reconnet.train -d EMNIST -s 12.5   # -s (sampling ratio in percent) is required; 12.5 assumed
2 | python -m benchmark.reconnet.train -d SVHN -s 12.5
3 | # python -m benchmark.reconnet.train -d STL10 -s 12.5
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torchmetrics==0.3.1
2 | torch==1.8.1+cu111
3 | pytorch_lightning==1.3.0
4 | scipy==1.6.3
5 | torchvision==0.9.1+cu111
6 | numpy==1.20.2
7 | matplotlib==3.4.2
8 | opencv_python==4.5.2.52
9 | albumentations==0.5.2
10 | pandas==1.2.4
11 | PyYAML==5.4.1
12 | # Note: the +cu111 builds of torch/torchvision are not on PyPI; install with -f https://download.pytorch.org/whl/torch_stable.html
--------------------------------------------------------------------------------
/src/scripts/train_all.sh:
--------------------------------------------------------------------------------
1 | python -m benchmark.reconnet.train -d EMNIST -s 12.5   # -s (sampling ratio in percent) is required; 12.5 assumed
2 | python -m benchmark.reconnet.train -d SVHN -s 12.5
3 | python -m benchmark.reconnet.train -d STL10 -s 12.5
4 | python -m benchmark.scsnet.train -d EMNIST -s 12.5
5 | python -m benchmark.scsnet.train -d SVHN -s 12.5
6 | python -m benchmark.scsnet.train -d STL10 -s 12.5
7 | python train.py -d EMNIST -s 12.5
8 | python train.py -d SVHN -s 12.5
9 | python train.py -d STL10 -s 12.5
--------------------------------------------------------------------------------
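
All of these commands resolve configs and data relative to the source root (`../config`, `../input`), so the scripts are presumably meant to be launched from inside `src/`:

```sh
cd src
sh scripts/train_all.sh
```
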
/config/inference_config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # This yaml file has the configurations for inference.
3 |
4 | gpu: True
5 | # dataset: "STL10"
6 | sampling_ratio: 0.125
7 | measurement_matrix: "../input/phi_block_binary.npy"
8 | real_data:
9 | - filename: "G.mat"
10 | min: 0.001
11 | max: 0.105
12 | - filename: "H.mat"
13 | min: 0.001
14 | max: 0.105
15 | - filename: "plus.mat"
16 | min: 0.01
17 | max: 0.105
18 | - filename: "S.mat"
19 | min: 0.001
20 | max: 0.105
21 | - filename: "R.mat"
22 | min: 0.001
23 | max: 0.12
24 | - filename: "F.mat"
25 | min: 0.002
26 | max: 0.4
--------------------------------------------------------------------------------
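
Each `real_data` entry pairs a recorded single-pixel measurement file with its voltage range; the `min`/`max` values correspond to the `low`/`high` arguments of `voltage2pixel` in `src/utils/utils.py`. A minimal sketch of that wiring, run from `src/` (the `.mat` key `"y"` and the file location are hypothetical):

```python
import numpy as np
import scipy.io
from utils import load_config, voltage2pixel

config = load_config("../config/inference_config.yaml")
phi = np.load(config["measurement_matrix"])

entry = config["real_data"][0]             # the "G.mat" entry
mat = scipy.io.loadmat(entry["filename"])  # path handling simplified for this sketch
y_volt = mat["y"]                          # hypothetical variable name inside the .mat file
y_pixel = voltage2pixel(y_volt, phi, entry["min"], entry["max"])
```
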
/config/scsnet_config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | model: "SCSNet"
3 | bcs: True
4 |
5 | data_module:
6 | data_dir: "../input"
7 | batch_size: 128
8 | val_percent: 0.1
9 | num_workers: 0
10 |
11 | learner:
12 | criterion: "L1"
13 | val_metrics: ["psnr"]
14 | test_metrics: ["psnr", "ssim"]
15 | lr: 0.0003
16 | scheduler:
17 | type: "reduce_lr_on_plateau" # or "one_cycle"
18 | args_reduce_lr_on_plateau:
19 | factor: 0.3
20 | patience: 2
21 | verbose: True
22 | arg_one_cycle:
23 | pct_start: 0.3
24 | verbose: True
25 |
26 | callbacks:
27 | checkpoint:
28 | monitor: "val_loss"
29 | save_last: True
30 | filename: "best"
31 | early_stopping:
32 | monitor: "val_loss"
33 | patience: 5
34 |
35 | trainer:
36 | epochs: 50
37 | gpu: 1
38 | fp16: True
39 |
--------------------------------------------------------------------------------
/config/reconnet_config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | model: "reconnet"
3 | img_dim: 32
4 | bcs: True
5 |
6 | data_module:
7 | data_dir: "../input"
8 | batch_size: 128
9 | val_percent: 0.1
10 | num_workers: 4
11 |
12 | learner:
13 | criterion: "L1"
14 | val_metrics: ["psnr"]
15 | test_metrics: ["psnr", "ssim"]
16 | lr: 0.001
17 | scheduler:
18 | type: "reduce_lr_on_plateau" # or "one_cycle"
19 | args_reduce_lr_on_plateau:
20 | factor: 0.3
21 | patience: 2
22 | verbose: True
23 | arg_one_cycle:
24 | pct_start: 0.3
25 | verbose: True
26 |
27 | callbacks:
28 | checkpoint:
29 | monitor: "val_loss"
30 | save_last: True
31 | filename: "best"
32 | early_stopping:
33 | monitor: "val_loss"
34 | patience: 5
35 |
36 | trainer:
37 | epochs: 40
38 | gpu: 1
39 | fp16: True
40 |
--------------------------------------------------------------------------------
/src/model/bcsunet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from .upsamplenet import UpsampleNet
3 | from .unet import UNet
4 | from .utils import init_weights
5 |
6 |
7 | class BCSUNet(nn.Module):
8 | """Combines simple residual upsample model and U-Net model"""
9 |
10 | def __init__(self, config):
11 | super().__init__()
12 | upsamplenet = UpsampleNet(
13 | sampling_ratio=config["sampling_ratio"],
14 | upsamplenet_config=config["net"]["upsamplenet"],
15 | )
16 | self.upsamplenet = init_weights(
17 | upsamplenet, init_type=config["net"]["upsamplenet"]["init_type"]
18 | )
19 | unet = UNet(config["net"]["unet"], input_nc=8)
20 | self.unet = init_weights(unet, init_type=config["net"]["unet"]["init_type"])
21 |
22 | def forward(self, x):
23 | x = self.upsamplenet(x)
24 | out = self.unet(x)
25 | return out
26 |
--------------------------------------------------------------------------------
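
`BCSUNet` reads only `sampling_ratio` and the two `net` sub-dicts from its config. A minimal instantiation sketch using the values from `config/bcsunet_EMNIST.yaml` (the input layout follows `data/base_dataset.py`; `UpsampleNet` itself is not shown in this dump):

```python
from model.bcsunet import BCSUNet

config = {
    "sampling_ratio": 0.125,
    "net": {
        "upsamplenet": {
            "init_type": "xavier",
            "out_channels_1": 16,
            "out_channels_2": 8,
            "use_transpose_conv": False,
            "spectral_norm": True,
        },
        "unet": {"init_type": "xavier", "channels": 16, "use_dropout": False},
    },
}
model = BCSUNet(config)
# Expected input: BCS measurements shaped (B, sampling_ratio * 16, H/4, W/4),
# e.g. (B, 2, 8, 8) for 32x32 images at a 0.125 sampling ratio.
```
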
/config/bcsunet_STL10.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | model: "BCS-UNet"
3 | bcs: True
4 |
5 | data_module:
6 | data_dir: "../input"
7 | batch_size: 64
8 | val_percent: 0.1
9 | num_workers: 0
10 |
11 | net:
12 | upsamplenet:
13 | init_type: "xavier"
14 | out_channels_1: 16
15 | out_channels_2: 8
16 | use_transpose_conv: False
17 | spectral_norm: True
18 | unet:
19 | init_type: "xavier"
20 | channels: 16
21 | use_dropout: True
22 |
23 | learner:
24 | intermediate_image: True
25 | criterion: "L1"
26 | val_metrics: ["psnr"]
27 | test_metrics: ["psnr", "ssim"]
28 | lr: 0.0005
29 | weight_decay: 0.0
30 | scheduler:
31 | type: "reduce_lr_on_plateau" # or "one_cycle"
32 | args_reduce_lr_on_plateau:
33 | factor: 0.3
34 | patience: 2
35 | verbose: True
36 | arg_one_cycle:
37 | pct_start: 0.3
38 | verbose: True
39 |
40 | callbacks:
41 | checkpoint:
42 | monitor: "val_loss"
43 | save_last: True
44 | filename: "best"
45 | early_stopping:
46 | monitor: "val_loss"
47 | patience: 6
48 |
49 | trainer:
50 | epochs: 50
51 | gpu: 1
52 | fp16: True
53 |
--------------------------------------------------------------------------------
/config/bcsunet_SVHN.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | model: "BCS-UNet"
3 | bcs: True
4 |
5 | data_module:
6 | data_dir: "../input"
7 | batch_size: 128
8 | val_percent: 0.1
9 | num_workers: 4
10 |
11 | net:
12 | upsamplenet:
13 | init_type: "xavier"
14 | out_channels_1: 16
15 | out_channels_2: 8
16 | use_transpose_conv: False
17 | spectral_norm: True
18 | unet:
19 | init_type: "xavier"
20 | channels: 16
21 | use_dropout: False
22 |
23 | learner:
24 | intermediate_image: True
25 | criterion: "L1"
26 | val_metrics: ["psnr"]
27 | test_metrics: ["psnr", "ssim"]
28 | lr: 0.001
29 | weight_decay: 0.001
30 | scheduler:
31 | type: "reduce_lr_on_plateau" # or "one_cycle"
32 | args_reduce_lr_on_plateau:
33 | factor: 0.3
34 | patience: 2
35 | verbose: True
36 | arg_one_cycle:
37 | pct_start: 0.3
38 | verbose: True
39 |
40 | callbacks:
41 | checkpoint:
42 | monitor: "val_loss"
43 | save_last: True
44 | filename: "best"
45 | early_stopping:
46 | monitor: "val_loss"
47 | patience: 5
48 |
49 | trainer:
50 | epochs: 50
51 | gpu: 1
52 | fp16: True
53 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Stephen Lau
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/config/bcsunet_EMNIST.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | model: "BCS-UNet"
3 | sampling_ratio: 0.125
4 | bcs: True
5 |
6 | data_module:
7 | data_dir: "../input"
8 | batch_size: 128
9 | val_percent: 0.1
10 | num_workers: 4
11 |
12 | net:
13 | upsamplenet:
14 | init_type: "xavier"
15 | out_channels_1: 16
16 | out_channels_2: 8
17 | use_transpose_conv: False
18 | spectral_norm: True
19 | unet:
20 | init_type: "xavier"
21 | channels: 16
22 | use_dropout: False
23 |
24 | learner:
25 | intermediate_image: True
26 | criterion: "L1"
27 | val_metrics: ["psnr"]
28 | test_metrics: ["psnr", "ssim"]
29 | lr: 0.0003
30 | weight_decay: 0.01
31 | scheduler:
32 | type: "reduce_lr_on_plateau" # or "one_cycle"
33 | args_reduce_lr_on_plateau:
34 | factor: 0.3
35 | patience: 2
36 | verbose: True
37 | arg_one_cycle:
38 | pct_start: 0.3
39 | verbose: True
40 |
41 | callbacks:
42 | checkpoint:
43 | monitor: "val_loss"
44 | save_last: True
45 | filename: "best"
46 | early_stopping:
47 | monitor: "val_loss"
48 | patience: 5
49 |
50 | trainer:
51 | epochs: 50
52 | gpu: 1
53 | fp16: True
54 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BCS-UNet
2 |
3 |
4 | ## About The Project
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | ## Getting Started
14 |
15 | To get a local copy up and running, follow these simple example steps.
16 |
17 |
18 | ### Installation
19 |
20 | 1. Clone the repo
21 | ```sh
22 | git clone https://github.com/stephenllh/bcs-unet.git
23 | ```
24 |
25 | 2. Change directory
26 | ```sh
27 | cd bcs-unet
28 | ```
29 |
30 | 3. Install packages
31 | ```sh
32 | pip install -r requirements.txt
33 | ```
34 |
35 |
36 |
37 | ## Usage
38 |
39 | 1. To train, run the training script with a dataset (`-d`) and a sampling ratio in percent (`-s`)
40 | ```sh
41 | cd src
42 | python train.py -d EMNIST -s 12.5
43 | ```
44 |
45 |
46 |
47 |
48 | ## License
49 |
50 | Distributed under the MIT License. See `LICENSE` for more information.
51 |
52 |
53 |
54 |
55 | ## Contact
56 |
57 | Stephen Lau - [Email](mailto:stephenlaulh@gmail.com) - [Twitter](https://twitter.com/StephenLLH) - [Kaggle](https://www.kaggle.com/faraksuli)
58 |
59 |
60 |
--------------------------------------------------------------------------------
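
A fuller usage sketch than the README's: every entry point lives under `src/` and takes a dataset (`-d`) plus a sampling ratio in percent (`-s`); the 12.5 below is illustrative:

```sh
cd src
python train.py -d EMNIST -s 12.5                     # train BCS-UNet
python -m benchmark.scsnet.train -d EMNIST -s 12.5    # train the SCSNet baseline
python -m benchmark.reconnet.train -d EMNIST -s 12.5  # train the ReconNet baseline
python evaluate.py -d EMNIST -s 12.5                  # evaluate a trained BCS-UNet
python inference_simulated.py -d EMNIST -s 12.5       # save sample reconstructions
```
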
/src/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class BaseDataset:
8 | def __init__(self, sampling_ratio: float, bcs: bool):
9 | if bcs:
10 | phi = np.load("../input/phi_block_binary.npy")
11 | phi = phi[: int(sampling_ratio * 16)]
12 | phi = torch.FloatTensor(phi)
13 | # print("phi_shape", phi.shape)
14 | self.cs_operator = BCSOperator(phi)
15 | else: # TODO: prepare for CCS
16 | pass
17 | # phi = np.load("../input/________.npy")
18 | # self.cs_operator = CCSOperator(phi)
19 |
20 |
21 | class BCSOperator(nn.Module):
22 | """
23 | This module is not trainable. It acts as a block compressive sensing operator,
24 | implemented as a fixed strided convolution to reuse PyTorch's optimized conv kernels.
25 | """
26 |
27 | def __init__(self, phi):
28 | super().__init__()
29 | self.register_buffer("phi", phi)
30 |
31 | def forward(self, x):
32 | out = F.conv2d(x, self.phi, stride=4)
33 | return out
34 |
35 |
36 | class CCSOperator(nn.Module):
37 | """
38 | This module is not trainable. It acts as a full-image (conventional) compressive sensing
39 | operator, implemented as a fixed convolution to reuse PyTorch's optimized conv kernels.
40 | """
41 |
42 | def __init__(self, phi):
43 | super().__init__()
44 | self.register_buffer("phi", phi)
45 |
46 | def forward(self, x):
47 | assert x.shape[-1] == self.phi.shape[-1]
48 | out = F.conv2d(x, self.phi, stride=1)
49 | return out
50 |
--------------------------------------------------------------------------------
/src/engine/dispatcher.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.optim import lr_scheduler
3 | import torchmetrics
4 |
5 |
6 | def get_scheduler(optimizer, config):
7 | scheduler_config = config["learner"]["scheduler"]
8 | scheduler_type = scheduler_config["type"]
9 |
10 | if scheduler_type == "reduce_lr_on_plateau":
11 | scheduler = lr_scheduler.ReduceLROnPlateau(
12 | optimizer, **scheduler_config["args_reduce_lr_on_plateau"]
13 | )
14 |
15 | # TODO: need to work on this (calculate the total number of steps)
16 | elif scheduler_type == "one_cycle":
17 | scheduler = lr_scheduler.OneCycleLR(
18 | optimizer,
19 | **scheduler_config["arg_one_cycle"],
20 | epochs=config["trainer"]["epochs"],
21 | steps_per_epoch=100
22 | )
23 |
24 | else:
25 | raise NotImplementedError("This scheduler is not implemented.")
26 |
27 | return scheduler
28 |
29 |
30 | def get_criterion(config):
31 | if config["learner"]["criterion"] == "cross_entropy":
32 | return nn.CrossEntropyLoss()
33 |
34 | elif config["learner"]["criterion"] == "L1":
35 | return nn.L1Loss()
36 |
37 | else:
38 | raise NotImplementedError("This loss function is not implemented.")
39 |
40 |
41 | def get_metrics(metric_name, config):
42 | if metric_name == "psnr":
43 | return torchmetrics.PSNR(data_range=1.0, dim=(-2, -1))
44 |
45 | elif metric_name == "ssim":
46 | return torchmetrics.SSIM(data_range=1.0)
47 |
48 | else:
49 | raise NotImplementedError("This metric is not implemented.")
50 |
--------------------------------------------------------------------------------
/src/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pytorch_lightning as pl
4 | from pytorch_lightning.utilities.seed import seed_everything
5 | from data.emnist import EMNISTDataModule
6 | from data.svhn import SVHNDataModule
7 | from data.stl10 import STL10DataModule
8 | from engine.learner import BCSUNetLearner
9 | from utils import load_config
10 |
11 |
12 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
13 |
14 | parser = argparse.ArgumentParser(description="Evaluate BCS-UNet on a test dataset")
15 | parser.add_argument("-d", "--dataset", type=str, help="'EMNIST', 'SVHN', or 'STL10'")
16 | parser.add_argument(
17 | "-s",
18 | "--sampling_ratio",
19 | type=float,
20 | required=True,
21 | help="Sampling ratio in percentage",
22 | )
23 |
24 | args = parser.parse_args()
25 |
26 |
27 | def run():
28 | seed_everything(seed=0, workers=True)
29 | path = f"../logs/BCSUNet_{args.dataset}_{int(args.sampling_ratio * 100):04d}"
30 | config = load_config(f"{path}/version_0/hparams.yaml")
31 |
32 | if args.dataset == "EMNIST":
33 | data_module = EMNISTDataModule(config)
34 | elif args.dataset == "SVHN":
35 | data_module = SVHNDataModule(config)
36 | elif args.dataset == "STL10":
37 | data_module = STL10DataModule(config)
38 | else:
39 | raise NotImplementedError
40 |
41 | PATH = f"{path}/version_0/checkpoints/best.ckpt"
42 | learner = BCSUNetLearner.load_from_checkpoint(PATH, config=config)
43 |
44 | trainer = pl.Trainer(
45 | gpus=1,
46 | default_root_dir="../",
47 | logger=False,
48 | )
49 | trainer.test(learner, datamodule=data_module)
50 |
51 |
52 | if __name__ == "__main__":
53 | run()
54 |
--------------------------------------------------------------------------------
/src/benchmark/reconnet/net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class ReconNet(nn.Module):
6 | def __init__(self, num_measurements, img_dim):
7 | super().__init__()
8 | self.img_dim = img_dim
9 | self.linear = nn.Linear(num_measurements, img_dim * img_dim)
10 | self.bn = nn.BatchNorm1d(img_dim * img_dim)  # defined but never used in forward()
11 | self.convs = nn.Sequential(
12 | ConvBlock(in_channels=1, out_channels=64, kernel_size=11),
13 | ConvBlock(in_channels=64, out_channels=32, kernel_size=1),
14 | ConvBlock(in_channels=32, out_channels=1, kernel_size=7),
15 | ConvBlock(in_channels=1, out_channels=64, kernel_size=11),
16 | ConvBlock(in_channels=64, out_channels=32, kernel_size=1),
17 | )
18 | self.final_conv = nn.Conv2d(32, 1, kernel_size=7, padding=7 // 2)
19 |
20 | def forward(self, x):
21 | if x.dim() == 4: # BCS
22 | x = x.view(x.shape[0], -1)
23 | x = x.unsqueeze(dim=1)
24 | x = self.linear(x)
25 | x = x.view(-1, 1, self.img_dim, self.img_dim)
26 | x = self.convs(x)
27 | x = self.final_conv(x)
28 | out = torch.sigmoid(x)
29 | return out
30 |
31 |
32 | class ConvBlock(nn.Module):
33 | def __init__(self, in_channels, out_channels, kernel_size, stride=1):
34 | super().__init__()
35 | self.conv = nn.Conv2d(
36 | in_channels,
37 | out_channels,
38 | kernel_size=kernel_size,
39 | stride=stride,
40 | padding=kernel_size // 2,
41 | bias=True,
42 | )
43 | self.relu = nn.ReLU(inplace=True)
44 | self.bn = nn.BatchNorm2d(out_channels)
45 |
46 | def forward(self, x):
47 | x = self.conv(x)
48 | x = self.bn(x)
49 | out = self.relu(x)
50 | return out
51 |
--------------------------------------------------------------------------------
/src/model/utils.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | def init_weights(net, init_type, init_gain=0.02):
5 | """
6 | Initialize network weights.
7 | Parameters:
8 | net (network) -- network to be initialized
9 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
10 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
11 | """
12 |
13 | def init_func(m): # define the initialization function
14 | classname = m.__class__.__name__
15 | if hasattr(m, "weight") and (
16 | classname.find("Conv") != -1 or classname.find("Linear") != -1
17 | ):
18 |
19 | if init_type == "normal":
20 | nn.init.normal_(m.weight.data, 0.0, init_gain)
21 |
22 | elif init_type == "xavier":
23 | nn.init.xavier_normal_(m.weight.data, gain=init_gain)
24 |
25 | elif init_type == "kaiming":
26 | nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
27 |
28 | elif init_type == "orthogonal":
29 | nn.init.orthogonal_(m.weight.data, gain=init_gain)
30 |
31 | else:
32 | raise NotImplementedError(
33 | "Initialization method [%s] is not implemented" % init_type
34 | )
35 |
36 | if hasattr(m, "bias") and m.bias is not None:
37 | nn.init.constant_(m.bias.data, 0.0)
38 |
39 | elif (
40 | classname.find("BatchNorm2d") != -1
41 | ): # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
42 | nn.init.normal_(m.weight.data, 1.0, init_gain)
43 | nn.init.constant_(m.bias.data, 0.0)
44 |
45 | # print("Initialize network with %s" % init_type)
46 | net.apply(init_func) # apply the initialization function
47 | return net
48 |
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import torch
4 | import math
5 | import cv2
6 | import numpy as np
7 |
8 | def voltage2pixel(y, phi, low, high):
9 | """
10 | Converts the measurement tensor / vector from voltage scale to pixel scale,
11 | so that the neural network can operate on it.
12 | y - measurement
13 | phi - block measurement matrix
14 | low - lowest voltage pixel
15 | high - highest voltage pixel
16 | """
17 |
18 | if isinstance(y, np.ndarray):
19 | y = torch.from_numpy(y).float()
20 |
21 | phi = torch.from_numpy(phi).float()
22 |
23 | if y.dim() == 4 and y.shape[0] == 1:
24 | y = y.squeeze(dim=0)
25 |
26 | if phi.dim() == 4 and phi.shape[1] == 1:
27 | phi = phi.squeeze(dim=1)
28 |
29 | term1 = y / (high - low)
30 | term2 = (phi.sum(dim=(1, 2)) * low / (high - low)).unsqueeze(-1).unsqueeze(-1)
31 |
32 | y_pixel_scale = term1 - term2
33 | return y_pixel_scale
34 |
35 |
36 | def load_config(config_path):
37 | with open(os.path.join(config_path)) as file:
38 | config = yaml.safe_load(file)
39 | return config
40 |
41 |
42 | def reshape_into_block(y, sr: float, block_size=4):
43 | c = int(sr * block_size ** 2)
44 | h = w = int(math.sqrt(y.shape[0] // c))
45 | y = y.reshape(h, w, c)
46 | y = y.transpose((2, 0, 1))
47 | return y
48 |
49 |
50 | def create_patches(root_dir, size):
51 | print(f"Creating {size}x{size} image patches for STL10 test set.")
52 | dataset_dir = os.path.join(root_dir, "test_images")
53 | save_dir = os.path.join(root_dir, "test_images_32x32")
54 |
55 | if not os.path.exists(save_dir):
56 | os.mkdir(save_dir)
57 |
58 | if not os.path.exists(dataset_dir):
59 | os.mkdir(dataset_dir)
60 |
61 | filenames = os.listdir(dataset_dir)
62 | for filename in filenames:
63 | img = cv2.imread(f"{dataset_dir}/{filename}")
64 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
65 | img = img.reshape(img.shape[0] // size, size, img.shape[1] // size, size)
66 | img = img.transpose(0, 2, 1, 3).reshape(-1, size, size)
67 | for idx, img_patch in enumerate(img):
68 | save_filename = f"{save_dir}/{filename.split('.')[0]}_{idx}.png"
69 | cv2.imwrite(save_filename, img_patch)
70 |
--------------------------------------------------------------------------------
/src/benchmark/scsnet/net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class SCSNetInit(nn.Module):
6 | """The "initial reconstruction network" of SCSNet"""
7 |
8 | def __init__(self, in_channels, block_size=4):
9 | super().__init__()
10 | self.block_size = block_size
11 | self.conv = nn.Conv2d(in_channels, block_size ** 2, kernel_size=1)
12 |
13 | def forward(self, x):
14 | x = self.conv(x)
15 | out = self._permute(x)
16 | return out
17 |
18 | def _permute(self, x):
19 | B, C, H, W = x.shape
20 | x = x.permute(0, 2, 3, 1)
21 | x = x.view(B, H, W, self.block_size, self.block_size)
22 | x = x.permute(0, 1, 3, 2, 4).contiguous()
23 | out = x.view(-1, 1, H * self.block_size, W * self.block_size)
24 | return out
25 |
26 |
27 | class SCSNetDeep(nn.Module):
28 | """The "deep reconstruction network" of SCSNet"""
29 |
30 | def __init__(self):
31 | super().__init__()
32 | middle_convs = [
33 | ConvBlock(in_channels=32, out_channels=32, kernel_size=3) for _ in range(13)
34 | ]
35 | self.convs = nn.Sequential(
36 | ConvBlock(in_channels=1, out_channels=128, kernel_size=3),
37 | ConvBlock(in_channels=128, out_channels=32, kernel_size=3),
38 | *middle_convs,
39 | ConvBlock(in_channels=32, out_channels=128, kernel_size=3),
40 | nn.Conv2d(
41 | in_channels=128,
42 | out_channels=1,
43 | kernel_size=3,
44 | padding=1,
45 | bias=False,
46 | ),
47 | )
48 |
49 | def forward(self, x):
50 | out = self.convs(x)
51 | return torch.sigmoid(x + out)
52 |
53 |
54 | class ConvBlock(nn.Module):
55 | def __init__(self, in_channels, out_channels, kernel_size, stride=1):
56 | super().__init__()
57 | self.conv = nn.Conv2d(
58 | in_channels,
59 | out_channels,
60 | kernel_size=kernel_size,
61 | stride=stride,
62 | padding=kernel_size // 2,
63 | bias=True,
64 | )
65 | self.relu = nn.ReLU(inplace=True)
66 | self.bn = nn.BatchNorm2d(out_channels)
67 |
68 | def forward(self, x):
69 | x = self.conv(x)
70 | x = self.bn(x)
71 | out = self.relu(x)
72 | return out
73 |
--------------------------------------------------------------------------------
/src/inference_simulated.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 | from pathlib import Path
5 | import warnings
6 | import cv2
7 | import scipy.ndimage
8 | from data.emnist import EMNISTDataModule
9 | from data.svhn import SVHNDataModule
10 | from data.stl10 import STL10DataModule
11 | from engine.learner import BCSUNetLearner
12 | from utils import load_config
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument(
17 | "-d", "--dataset", type=str, required=True, help="'EMNIST', 'SVHN', or 'STL10'"
18 | )
19 | parser.add_argument(
20 | "-s",
21 | "--sampling_ratio",
22 | type=float,
23 | required=True,
24 | help="Sampling ratio in percentage",
25 | )
26 | args = parser.parse_args()
27 |
28 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
29 | warnings.simplefilter("ignore")
30 |
31 |
32 | def run():
33 | dataset = args.dataset
34 | sr = args.sampling_ratio
35 | checkpoint_path = (
36 | f"../logs/BCSUNet_{dataset}_{int(sr * 100):04d}/best/checkpoints/last.ckpt"
37 | )
38 | train_config_path = os.path.join(
39 | Path(checkpoint_path).parent.parent, "hparams.yaml"
40 | )
41 | train_config = load_config(train_config_path)
42 | train_config["sampling_ratio"] = sr / 100
43 |
44 | if dataset == "EMNIST":
45 | data_module = EMNISTDataModule(train_config)
46 | elif dataset == "SVHN":
47 | data_module = SVHNDataModule(train_config)
48 | elif dataset == "STL10":
49 | data_module = STL10DataModule(train_config)
50 |
51 | learner = BCSUNetLearner.load_from_checkpoint(
52 | checkpoint_path=checkpoint_path, config=train_config, strict=False
53 | )
54 |
55 | message = f"Inference: BCS-UNet on {dataset} dataset. Sampling ratio = {train_config['sampling_ratio']}"
56 | print(message)
57 |
58 | data_module.setup()
59 | ds = data_module.test_dataset
60 |
61 | directory = f"../inference_images/BCSUNet/{dataset}/{int(sr * 100):04d}"
62 | os.makedirs(directory, exist_ok=True)
63 |
64 | for i in np.linspace(0, len(ds) - 1, 30, dtype=int):
65 | input = ds[i][0].unsqueeze(0)
66 | out = learner(input).squeeze().squeeze().detach().numpy()
67 | out = scipy.ndimage.zoom(out, 8, order=0, mode="nearest")
68 | cv2.imwrite(f"{directory}/{i}.png", out * 255)
69 |
70 | print("Done.")
71 |
72 |
73 | if __name__ == "__main__":
74 | run()
75 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pytorch_lightning as pl
4 | from pytorch_lightning.callbacks import (
5 | ModelCheckpoint,
6 | EarlyStopping,
7 | LearningRateMonitor,
8 | )
9 | from pytorch_lightning.loggers import TensorBoardLogger
10 | from pytorch_lightning.utilities.seed import seed_everything
11 | from data.emnist import EMNISTDataModule
12 | from data.svhn import SVHNDataModule
13 | from data.stl10 import STL10DataModule
14 | from engine.learner import BCSUNetLearner
15 | from utils import load_config
16 |
17 |
18 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument(
22 | "-d", "--dataset", type=str, required=True, help="'EMNIST', 'SVHN', or 'STL10'"
23 | )
24 | parser.add_argument(
25 | "-s",
26 | "--sampling_ratio",
27 | type=float,
28 | required=True,
29 | help="Sampling ratio in percentage",
30 | )
31 | args = parser.parse_args()
32 |
33 |
34 | def run():
35 | seed_everything(seed=0, workers=True)
36 |
37 | config = load_config(f"../config/bcsunet_{args.dataset}.yaml")
38 | config["sampling_ratio"] = args.sampling_ratio / 100
39 |
40 | if args.dataset == "EMNIST":
41 | data_module = EMNISTDataModule(config)
42 | elif args.dataset == "SVHN":
43 | data_module = SVHNDataModule(config)
44 | elif args.dataset == "STL10":
45 | data_module = STL10DataModule(config)
46 |
47 | learner = BCSUNetLearner(config)
48 | callbacks = [
49 | ModelCheckpoint(**config["callbacks"]["checkpoint"]),
50 | EarlyStopping(**config["callbacks"]["early_stopping"]),
51 | LearningRateMonitor(),
52 | ]
53 | log_name = f"BCSUNet_{args.dataset}_{int(config['sampling_ratio'] * 10000):04d}"
54 | logger = TensorBoardLogger(save_dir="../logs", name=log_name)
55 |
56 | message = f"Running BCS-UNet on {args.dataset} dataset. Sampling ratio = {config['sampling_ratio']}"
57 | print("-" * 100)
58 | print(message)
59 | print("-" * 100)
60 |
61 | trainer = pl.Trainer(
62 | gpus=config["trainer"]["gpu"],
63 | max_epochs=config["trainer"]["epochs"],
64 | default_root_dir="../",
65 | callbacks=callbacks,
66 | precision=(16 if config["trainer"]["fp16"] else 32),
67 | logger=logger,
68 | )
69 | trainer.fit(learner, data_module)
70 | trainer.test(learner, datamodule=data_module, ckpt_path="best")
71 |
72 |
73 | if __name__ == "__main__":
74 | run()
75 |
--------------------------------------------------------------------------------
/src/benchmark/scsnet/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pytorch_lightning as pl
4 | from pytorch_lightning.callbacks import (
5 | ModelCheckpoint,
6 | EarlyStopping,
7 | LearningRateMonitor,
8 | )
9 | from pytorch_lightning.loggers import TensorBoardLogger
10 | from pytorch_lightning.utilities.seed import seed_everything
11 | from data.emnist import EMNISTDataModule
12 | from data.svhn import SVHNDataModule
13 | from data.stl10 import STL10DataModule
14 | from .learner import SCSNetLearner
15 | from utils import load_config
16 |
17 |
18 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("-d", "--dataset", type=str, help="'EMNIST', 'SVHN', or 'STL10'")
22 | parser.add_argument(
23 | "-s",
24 | "--sampling_ratio",
25 | type=float,
26 | required=True,
27 | help="Sampling ratio in percentage",
28 | )
29 | args = parser.parse_args()
30 |
31 |
32 | def run():
33 | seed_everything(seed=0, workers=True)
34 |
35 | config = load_config("../config/scsnet_config.yaml")
36 | config["sampling_ratio"] = args.sampling_ratio / 100
37 |
38 | if args.dataset == "EMNIST":
39 | data_module = EMNISTDataModule(config)
40 | elif args.dataset == "SVHN":
41 | data_module = SVHNDataModule(config)
42 | elif args.dataset == "STL10":
43 | config["data_module"]["batch_size"] = 64
44 | data_module = STL10DataModule(config)
45 | else:
46 | raise NotImplementedError
47 |
48 | learner = SCSNetLearner(config)
49 | callbacks = [
50 | ModelCheckpoint(**config["callbacks"]["checkpoint"]),
51 | EarlyStopping(**config["callbacks"]["early_stopping"]),
52 | LearningRateMonitor(),
53 | ]
54 |
55 | log_name = f"SCSNet_{args.dataset}_{int(config['sampling_ratio'] * 10000):04d}"
56 | logger = TensorBoardLogger(save_dir="../logs", name=log_name)
57 |
58 | message = f"Running SCSNet on {args.dataset} dataset. Sampling ratio = {config['sampling_ratio']}"
59 | print("-" * 100)
60 | print(message)
61 | print("-" * 100)
62 |
63 | trainer = pl.Trainer(
64 | gpus=config["trainer"]["gpu"],
65 | max_epochs=config["trainer"]["epochs"],
66 | default_root_dir="../",
67 | callbacks=callbacks,
68 | precision=(16 if config["trainer"]["fp16"] else 32),
69 | logger=logger,
70 | )
71 | trainer.fit(learner, data_module)
72 | trainer.test(learner, datamodule=data_module, ckpt_path="best")
73 |
74 |
75 | if __name__ == "__main__":
76 | run()
77 |
--------------------------------------------------------------------------------
/src/benchmark/scsnet/inference_simulated.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | from pathlib import Path
5 | import warnings
6 | import cv2
7 | import scipy.ndimage
8 | from data.emnist import EMNISTDataModule
9 | from data.svhn import SVHNDataModule
10 | from data.stl10 import STL10DataModule
11 | from .learner import SCSNetLearner
12 | from utils import load_config
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument(
17 | "-d", "--dataset", type=str, required=True, help="'EMNIST', 'SVHN', or 'STL10'"
18 | )
19 | parser.add_argument(
20 | "-s",
21 | "--sampling_ratio",
22 | type=float,
23 | required=True,
24 | help="Sampling ratio in percentage",
25 | )
26 | args = parser.parse_args()
27 |
28 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
29 | warnings.simplefilter("ignore")
30 |
31 |
32 | def run():
33 | # inference_config = load_config("../config/inference_config.yaml")
34 | dataset = args.dataset
35 | sr = args.sampling_ratio
36 | checkpoint_path = (
37 | f"../logs/SCSNet_{dataset}_{int(sr * 100):04d}/best/checkpoints/last.ckpt"
38 | )
39 | train_config_path = os.path.join(
40 | Path(checkpoint_path).parent.parent, "hparams.yaml"
41 | )
42 | train_config = load_config(train_config_path)
43 | train_config["sampling_ratio"] = sr / 100
44 |
45 | if dataset == "EMNIST":
46 | data_module = EMNISTDataModule(train_config)
47 | train_config["img_dim"] = 32
48 | elif dataset == "SVHN":
49 | data_module = SVHNDataModule(train_config)
50 | train_config["img_dim"] = 32
51 | elif dataset == "STL10":
52 | data_module = STL10DataModule(train_config)
53 | train_config["img_dim"] = 96
54 |
55 | learner = SCSNetLearner.load_from_checkpoint(
56 | checkpoint_path=checkpoint_path, config=train_config, strict=False
57 | )
58 |
59 | message = f"Inference: SCSNet on {dataset} dataset. Sampling ratio = {train_config['sampling_ratio']}"
60 | print(message)
61 |
62 | data_module.setup()
63 | ds = data_module.test_dataset
64 |
65 | directory = f"../inference_images/SCSNet/{dataset}/{int(sr * 100):04d}"
66 | os.makedirs(directory, exist_ok=True)
67 |
68 | for i in np.linspace(0, len(ds) - 1, 30, dtype=int):
69 | input = ds[i][0].unsqueeze(0)
70 | out = learner(input).squeeze().squeeze().detach().numpy()
71 | out = scipy.ndimage.zoom(out, 8, order=0, mode="nearest")
72 | cv2.imwrite(f"{directory}/{i}.png", out * 255)
73 |
74 | print("Done.")
75 |
76 |
77 | if __name__ == "__main__":
78 | run()
79 |
--------------------------------------------------------------------------------
/src/benchmark/reconnet/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pytorch_lightning as pl
4 | from pytorch_lightning.callbacks import (
5 | ModelCheckpoint,
6 | EarlyStopping,
7 | LearningRateMonitor,
8 | )
9 | from pytorch_lightning.loggers import TensorBoardLogger
10 | from pytorch_lightning.utilities.seed import seed_everything
11 | from data.emnist import EMNISTDataModule
12 | from data.svhn import SVHNDataModule
13 | from data.stl10 import STL10DataModule
14 | from .learner import ReconNetLearner
15 | from utils import load_config
16 |
17 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("-d", "--dataset", type=str, help="'EMNIST', 'SVHN', or 'STL10'")
21 | parser.add_argument(
22 | "-s",
23 | "--sampling_ratio",
24 | type=float,
25 | required=True,
26 | help="Sampling ratio in percentage",
27 | )
28 | args = parser.parse_args()
29 |
30 |
31 | def run():
32 | seed_everything(seed=0, workers=True)
33 |
34 | config = load_config("../config/reconnet_config.yaml")
35 | config["sampling_ratio"] = args.sampling_ratio / 100
36 |
37 | if args.dataset == "EMNIST":
38 | data_module = EMNISTDataModule(config)
39 | elif args.dataset == "SVHN":
40 | data_module = SVHNDataModule(config)
41 | elif args.dataset == "STL10":
42 | config["data_module"]["num_workers"] = 0
43 | data_module = STL10DataModule(config, reconnet=True)
44 | else:
45 | raise NotImplementedError
46 |
47 | learner = ReconNetLearner(config)
48 | callbacks = [
49 | ModelCheckpoint(**config["callbacks"]["checkpoint"]),
50 | EarlyStopping(**config["callbacks"]["early_stopping"]),
51 | LearningRateMonitor(),
52 | ]
53 |
54 | sampling_ratio = config["sampling_ratio"]
55 | log_name = f"ReconNet_{args.dataset}_{int(sampling_ratio * 10000):04d}"
56 | logger = TensorBoardLogger(save_dir="../logs", name=log_name)
57 |
58 | message = f"Running ReconNet on {args.dataset} dataset. Sampling ratio = {sampling_ratio * 100}%"
59 | print("-" * 100)
60 | print(message)
61 | print("-" * 100)
62 |
63 | trainer = pl.Trainer(
64 | gpus=config["trainer"]["gpu"],
65 | max_epochs=config["trainer"]["epochs"],
66 | default_root_dir="../",
67 | progress_bar_refresh_rate=20,
68 | callbacks=callbacks,
69 | precision=(16 if config["trainer"]["fp16"] else 32),
70 | logger=logger,
71 | )
72 | trainer.fit(learner, data_module)
73 | trainer.test(learner, datamodule=data_module, ckpt_path="best")
74 |
75 |
76 | if __name__ == "__main__":
77 | run()
78 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
140 | # VSCode
141 | **/.vscode
142 | **/__pycache__
143 |
144 | # Input data and models
145 | **/input/**
146 | !/input/phi_block_binary.npy
147 | **/logs
148 | **/lightning_logs
149 | **/inference_input
150 | **/inference_images
151 | **/wandb
152 | **/saved_models
153 | **/src/tests
154 |
155 | # Data files
156 | *.csv
157 | *.h5
158 | *.pkl
159 | *.pth
160 | *.bin
161 | *.pyc
162 | *.png
163 | *.ckpt
164 |
--------------------------------------------------------------------------------
/src/engine/learner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | from model.bcsunet import BCSUNet
4 | from engine.dispatcher import get_scheduler, get_criterion, get_metrics
5 |
6 |
7 | class BCSUNetLearner(pl.LightningModule):
8 | def __init__(self, config):
9 | super().__init__()
10 | self.net = BCSUNet(config)
11 | self.config = config
12 | self.criterion = get_criterion(config)
13 | self._set_metrics(config)
14 | self.save_hyperparameters(config)
15 |
16 | def forward(self, inputs):
17 | reconstructed_image = self.net(inputs)
18 | return reconstructed_image
19 |
20 | def configure_optimizers(self):
21 | optimizer = torch.optim.AdamW(
22 | self.parameters(),
23 | lr=self.config["learner"]["lr"],
24 | weight_decay=self.config["learner"]["weight_decay"],
25 | )
26 | scheduler = get_scheduler(optimizer, self.config)
27 | return {
28 | "optimizer": optimizer,
29 | "lr_scheduler": scheduler,
30 | "monitor": "val_loss",
31 | }
32 |
33 | def step(self, batch, mode="train"):
34 | inputs, targets = batch
35 | preds = self.net(inputs)
36 | loss = self.criterion(preds, targets)
37 | self.log(f"{mode}_loss", loss, prog_bar=True)
38 |
39 | if mode == "val":
40 | preds_ = preds.float().detach()
41 | for metric_name in self.config["learner"]["val_metrics"]:
42 | metric = self.__getattr__(f"{mode}_{metric_name}")
43 | if metric is not None:
44 | self.log(
45 | f"{mode}_{metric_name}",
46 | metric(preds_, targets),
47 | prog_bar=True,
48 | )
49 | return loss
50 |
51 | def training_step(self, batch, batch_idx):
52 | return self.step(batch, mode="train")
53 |
54 | def validation_step(self, batch, batch_idx):
55 | return self.step(batch, mode="val")
56 |
57 | def test_step(self, batch, batch_idx):
58 | inputs, targets = batch
59 | preds = self.net(inputs)
60 | for metric_name in self.config["learner"]["test_metrics"]:
61 | metric = self.__getattr__(f"test_{metric_name}")
62 | self.log(
63 | f"test_{metric_name}",
64 | metric(preds.float(), targets),
65 | on_step=False,
66 | on_epoch=True,
67 | prog_bar=True,
68 | )
69 |
70 | def _set_metrics(self, config):
71 | """
72 | Set TorchMetrics objects as attributes dynamically.
73 | For instance, `self.train_accuracy = torchmetrics.Accuracy()`
74 | """
75 | for metric_name in config["learner"]["val_metrics"]:
76 | # self.__setattr__(f"train_{metric_name}", get_metrics(metric_name, config))
77 | self.__setattr__(f"val_{metric_name}", get_metrics(metric_name, config))
78 |
79 | for metric_name in config["learner"]["test_metrics"]:
80 | self.__setattr__(f"test_{metric_name}", get_metrics(metric_name, config))
81 |
--------------------------------------------------------------------------------
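
Concretely, with the BCS-UNet configs above (`val_metrics: ["psnr"]`, `test_metrics: ["psnr", "ssim"]`), `_set_metrics` attaches the following instances via `engine/dispatcher.py`:

```python
# self.val_psnr  = torchmetrics.PSNR(data_range=1.0, dim=(-2, -1))
# self.test_psnr = torchmetrics.PSNR(data_range=1.0, dim=(-2, -1))
# self.test_ssim = torchmetrics.SSIM(data_range=1.0)
# step() / test_step() then look them up by name with self.__getattr__.
```
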
/src/benchmark/reconnet/learner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | from .net import ReconNet
4 | from engine.dispatcher import get_scheduler, get_criterion, get_metrics
5 |
6 |
7 | class ReconNetLearner(pl.LightningModule):
8 | def __init__(self, config):
9 | super().__init__()
10 | y_dim = int(config["sampling_ratio"] * config["img_dim"] ** 2)
11 | self.net = ReconNet(y_dim, config["img_dim"])
12 | self.config = config
13 | self.criterion = get_criterion(config)
14 | self._set_metrics(config)
15 | self.save_hyperparameters(config)
16 |
17 | def forward(self, inputs):
18 | return self.net(inputs)
19 |
20 | def configure_optimizers(self):
21 | trainable_params = filter(lambda p: p.requires_grad, self.net.parameters())
22 | optimizer = torch.optim.Adam(
23 | trainable_params,
24 | lr=self.config["learner"]["lr"],
25 | )
26 | scheduler = get_scheduler(optimizer, self.config)
27 | return {
28 | "optimizer": optimizer,
29 | "lr_scheduler": scheduler,
30 | "monitor": "val_loss",
31 | }
32 |
33 | def step(self, batch, mode="train"):
34 | inputs, targets = batch
35 | preds = self.net(inputs)
36 | loss = self.criterion(preds, targets)
37 |
38 | self.log(f"{mode}_loss", loss, prog_bar=True)
39 |
40 | preds_ = preds.float().detach()
41 |
42 | # Log validation metrics
43 | if mode == "val":
44 | for metric_name in self.config["learner"]["val_metrics"]:
45 | metric = self.__getattr__(f"{mode}_{metric_name}")
46 | self.log(
47 | f"{mode}_{metric_name}",
48 | metric(preds_, targets),
49 | prog_bar=True,
50 | )
51 | return loss
52 |
53 | def training_step(self, batch, batch_idx):
54 | return self.step(batch, mode="train")
55 |
56 | def validation_step(self, batch, batch_idx):
57 | return self.step(batch, mode="val")
58 |
59 | def test_step(self, batch, batch_idx):
60 | inputs, targets = batch
61 | preds = self.net(inputs)
62 | for metric_name in self.config["learner"]["test_metrics"]:
63 | metric = self.__getattr__(f"test_{metric_name}")
64 | self.log(
65 | f"test_{metric_name}",
66 | metric(preds.float(), targets),
67 | prog_bar=True,
68 | on_step=False,
69 | on_epoch=True,
70 | )
71 |
72 | def _set_metrics(self, config):
73 | """
74 | Set TorchMetrics objects as attributes dynamically.
75 | For instance, `self.train_accuracy = torchmetrics.Accuracy()`
76 | """
77 | for metric_name in config["learner"]["val_metrics"]:
78 | # self.__setattr__(f"train_{metric_name}", get_metrics(metric_name, config))
79 | self.__setattr__(f"val_{metric_name}", get_metrics(metric_name, config))
80 |
81 | for metric_name in config["learner"]["test_metrics"]:
82 | self.__setattr__(f"test_{metric_name}", get_metrics(metric_name, config))
83 |
--------------------------------------------------------------------------------
/src/benchmark/scsnet/learner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | from .net import SCSNetInit, SCSNetDeep
4 | from engine.dispatcher import get_scheduler, get_criterion, get_metrics
5 |
6 |
7 | class SCSNetLearner(pl.LightningModule):
8 | def __init__(self, config):
9 | super().__init__()
10 | in_channels = int(config["sampling_ratio"] * 16)
11 | self.net1 = SCSNetInit(in_channels)
12 | self.net2 = SCSNetDeep()
13 | self.config = config
14 | self.criterion = get_criterion(config)
15 | self._set_metrics(config)
16 | self.save_hyperparameters(config)
17 |
18 | def forward(self, inputs):
19 | return self.net2(self.net1(inputs))
20 |
21 | def configure_optimizers(self):
22 | optimizer = torch.optim.Adam(
23 | self.parameters(),
24 | lr=self.config["learner"]["lr"],
25 | )
26 | scheduler = get_scheduler(optimizer, self.config)
27 | return {
28 | "optimizer": optimizer,
29 | "lr_scheduler": scheduler,
30 | "monitor": "val_loss",
31 | }
32 |
33 | def step(self, batch, mode="train"):
34 | inputs, targets = batch
35 |
36 | preds1 = self.net1(inputs)
37 | preds2 = self.net2(preds1)
38 |
39 | loss1 = self.criterion(preds1, targets)
40 | loss2 = self.criterion(preds2, targets)
41 | loss = loss1 + loss2
42 |
43 | preds_ = preds2.float().detach()
44 | targets_ = targets.detach()
45 |
46 | # Log validation metrics
47 | if mode == "val":
48 | self.log(f"{mode}_loss", loss2, prog_bar=True)
49 | for metric_name in self.config["learner"]["val_metrics"]:
50 | metric = self.__getattr__(f"{mode}_{metric_name}")
51 | self.log(
52 | f"{mode}_{metric_name}",
53 | metric(preds_, targets_),
54 | prog_bar=True,
55 | )
56 | return loss
57 |
58 | def training_step(self, batch, batch_idx):
59 | return self.step(batch, mode="train")
60 |
61 | def validation_step(self, batch, batch_idx):
62 | return self.step(batch, mode="val")
63 |
64 | def test_step(self, batch, batch_idx):
65 | inputs, targets = batch
66 | preds1 = self.net1(inputs)
67 | preds2 = self.net2(preds1).float()
68 | for metric_name in self.config["learner"]["test_metrics"]:
69 | metric = self.__getattr__(f"test_{metric_name}")
70 | self.log(
71 | f"test_{metric_name}",
72 | metric(preds2.float(), targets),
73 | on_step=False,
74 | on_epoch=True,
75 | prog_bar=True,
76 | )
77 |
78 | def _set_metrics(self, config):
79 | """
80 | Set TorchMetrics objects as attributes dynamically.
81 | For instance, `self.train_accuracy = torchmetrics.Accuracy()`
82 | """
83 | for metric_name in config["learner"]["val_metrics"]:
84 | # self.__setattr__(f"train_{metric_name}", get_metrics(metric_name, config))
85 | self.__setattr__(f"val_{metric_name}", get_metrics(metric_name, config))
86 |
87 | for metric_name in config["learner"]["test_metrics"]:
88 | self.__setattr__(f"test_{metric_name}", get_metrics(metric_name, config))
89 |
--------------------------------------------------------------------------------
/src/benchmark/reconnet/inference_simulated.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | from pathlib import Path
5 | import warnings
6 | import cv2
7 | import scipy.ndimage
8 | from data.emnist import EMNISTDataModule
9 | from data.svhn import SVHNDataModule
10 | from data.stl10 import STL10DataModule
11 | from .learner import ReconNetLearner
12 | from utils import load_config, create_patches
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument(
17 | "-d", "--dataset", type=str, required=True, help="'EMNIST', 'SVHN', or 'STL10'"
18 | )
19 | parser.add_argument(
20 | "-s",
21 | "--sampling_ratio",
22 | type=float,
23 | required=True,
24 | help="Sampling ratio in percentage",
25 | )
26 | args = parser.parse_args()
27 |
28 |
29 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
30 | warnings.simplefilter("ignore")
31 |
32 |
33 | def run():
34 | dataset = args.dataset
35 | sr = args.sampling_ratio
36 |
37 | checkpoint_folder = f"../logs/ReconNet_{dataset}_{int(sr * 100):04d}/best"
38 |
39 | if not os.path.exists(checkpoint_folder):
40 | run_name = os.listdir(Path(checkpoint_folder).parent)[0]
41 | checkpoint_path = (
42 | f"{Path(checkpoint_folder).parent}/{run_name}/checkpoints/last.ckpt"
43 | )
44 | message = (
45 | f"The checkpoint from the run '{run_name}' is selected by default."
46 | + "If this is not intended, change the name of the preferred checkpoint folder to 'best'."
47 | )
48 | print(message)
49 | else:
50 | checkpoint_path = f"{checkpoint_folder}/checkpoints/last.ckpt"
51 |
52 | train_config_path = os.path.join(
53 | Path(checkpoint_path).parent.parent, "hparams.yaml"
54 | )
55 | train_config = load_config(train_config_path)
56 | train_config["sampling_ratio"] = sr / 100
57 | train_config["img_dim"] = 32
58 |
59 | if dataset == "EMNIST":
60 | data_module = EMNISTDataModule(train_config)
61 | elif dataset == "SVHN":
62 | data_module = SVHNDataModule(train_config)
63 | elif dataset == "STL10":
64 | data_module = STL10DataModule(train_config, reconnet=True)
65 |
66 | learner = ReconNetLearner.load_from_checkpoint(
67 | checkpoint_path=checkpoint_path, config=train_config, strict=False
68 | )
69 |
70 | message = f"Inference: ReconNet on {dataset} dataset. Sampling ratio = {train_config['sampling_ratio']}"
71 | print(message)
72 |
73 | data_module.setup()
74 | ds = data_module.test_dataset
75 |
76 | directory = f"../inference_images/ReconNet/{dataset}/{int(sr * 100):04d}"
77 | os.makedirs(directory, exist_ok=True)
78 |
79 | if dataset != "STL10":
80 | for i in np.linspace(0, len(ds) - 1, 30, dtype=int):
81 | input = ds[i][0].unsqueeze(0)
82 | out = learner(input).squeeze().squeeze().detach().numpy()
83 | out = scipy.ndimage.zoom(out, 8, order=0, mode="nearest")
84 | cv2.imwrite(f"{directory}/{i}.png", out * 255)
85 |
86 | else:
87 | if not os.path.exists("../input/STL10/test_images_32x32"):
88 | create_patches("../input/STL10", size=32)
89 | for i in range(10):
90 | combined = np.zeros((96, 96), dtype=int)
91 | for j in range(9):
92 | input = ds[i * 9 + j][0].unsqueeze(0)
93 | out = learner(input).squeeze().squeeze().detach().numpy()
94 | out = (out * 255).astype(int)
95 | x = j // 3 * 32
96 | y = j % 3 * 32
97 | combined[x : x + 32, y : y + 32] = out
98 | combined = scipy.ndimage.zoom(combined, 8, order=0, mode="nearest")
99 | cv2.imwrite(f"{directory}/{i}.png", combined)
100 |
101 | print("Done.")
102 |
103 |
104 | if __name__ == "__main__":
105 | run()
106 |
--------------------------------------------------------------------------------
/src/data/svhn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.sampler import SubsetRandomSampler
4 | import torchvision
5 | from torchvision import transforms
6 | import pytorch_lightning as pl
7 | from .base_dataset import BaseDataset
8 |
9 |
10 | class SVHNDataset(BaseDataset):
11 | def __init__(self, sampling_ratio: float, bcs: bool, tfms=None, train=True):
12 | super().__init__(sampling_ratio=sampling_ratio, bcs=bcs)
13 | self.data = torchvision.datasets.SVHN(
14 | "../input/SVHN",
15 | split="train" if train else "test",
16 | transform=tfms,
17 | download=True,
18 | )
19 |
20 | def __getitem__(self, idx):
21 | image, _ = self.data[idx]
22 | image_ = image.unsqueeze(dim=1)
23 | y = self.cs_operator(image_)
24 | return y.squeeze(dim=0), image
25 |
26 | def __len__(self):
27 | return len(self.data)
28 |
29 |
30 | class SVHNDataModule(pl.LightningDataModule):
31 | def __init__(self, config):
32 | super().__init__()
33 | self.config = config
34 | self.dm_config = config["data_module"]
35 |
36 | def setup(self, stage=None):
37 | train_tfms = transforms.Compose(
38 | [
39 | transforms.Resize(32),
40 | transforms.RandomHorizontalFlip(p=0.5),
41 | transforms.ColorJitter(brightness=0.2, contrast=0.2),
42 | transforms.Grayscale(),
43 | transforms.ToTensor(),
44 | ]
45 | )
46 | val_tfms = transforms.Compose(
47 | [
48 | transforms.Resize(32),
49 | transforms.Grayscale(),
50 | transforms.ToTensor(),
51 | ]
52 | )
53 |
54 | self.train_dataset = SVHNDataset(
55 | sampling_ratio=self.config["sampling_ratio"],
56 | bcs=self.config["bcs"],
57 | tfms=train_tfms,
58 | train=True,
59 | )
60 |
61 | self.val_dataset = SVHNDataset(
62 | sampling_ratio=self.config["sampling_ratio"],
63 | bcs=self.config["bcs"],
64 | tfms=val_tfms,
65 | train=True,
66 | )
67 |
68 | self.test_dataset = SVHNDataset(
69 | sampling_ratio=self.config["sampling_ratio"],
70 | bcs=self.config["bcs"],
71 | tfms=val_tfms,
72 | train=False,
73 | )
74 |
75 | dataset_size = len(self.train_dataset)
76 | indices = torch.randperm(dataset_size)
77 | split = int(self.dm_config["val_percent"] * dataset_size)
78 | self.train_idx, self.valid_idx = indices[split:], indices[:split]
79 |
80 | def train_dataloader(self):
81 | train_sampler = SubsetRandomSampler(self.train_idx)
82 | return DataLoader(
83 | self.train_dataset,
84 | batch_size=self.dm_config["batch_size"],
85 | sampler=train_sampler,
86 | num_workers=self.dm_config["num_workers"],
87 | )
88 |
89 | def val_dataloader(self):
90 | val_sampler = SubsetRandomSampler(self.valid_idx)
91 | return DataLoader(
92 | self.val_dataset,
93 | batch_size=self.dm_config["batch_size"],
94 | sampler=val_sampler,
95 | num_workers=self.dm_config["num_workers"],
96 | )
97 |
98 | def test_dataloader(self):
99 | return DataLoader(
100 | self.test_dataset,
101 | batch_size=self.dm_config["batch_size"],
102 | num_workers=self.dm_config["num_workers"],
103 | )
104 |
105 | def predict_dataloader(self):
106 | return DataLoader(
107 | self.test_dataset,
108 | batch_size=self.dm_config["batch_size"],
109 | num_workers=self.dm_config["num_workers"],
110 | )
111 |
--------------------------------------------------------------------------------
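
For reference, a minimal sketch of a config that satisfies the keys SVHNDataModule reads above (sampling_ratio, bcs, and the data_module sub-dict). The import path follows the repo layout, and the values are placeholders, not the project's actual YAML:

    from data.svhn import SVHNDataModule

    config = {
        "sampling_ratio": 0.25,  # placeholder value
        "bcs": True,             # placeholder value
        "data_module": {"batch_size": 64, "num_workers": 2, "val_percent": 0.1},
    }
    dm = SVHNDataModule(config)
    dm.setup()
    y, image = next(iter(dm.train_dataloader()))  # measurements and ground truth
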
/src/data/emnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.sampler import SubsetRandomSampler
4 | import torchvision
5 | from torchvision import transforms
6 | import pytorch_lightning as pl
7 | from .base_dataset import BaseDataset
8 |
9 |
10 | class EMNISTDataset(BaseDataset):
11 | def __init__(self, sampling_ratio: float, bcs: bool, tfms=None, train=True):
12 | super().__init__(sampling_ratio=sampling_ratio, bcs=bcs)
13 | self.data = torchvision.datasets.EMNIST(
14 | "../input",
15 | split="letters", # we use letters only
16 | train=train,
17 | transform=tfms,
18 | download=True,
19 | )
20 |
21 | def __getitem__(self, idx):
22 | image, _ = self.data[idx]
23 | image_ = image.unsqueeze(dim=1)
24 | y = self.cs_operator(image_)
25 | return y.squeeze(dim=0), image
26 |
27 | def __len__(self):
28 | return len(self.data)
29 |
30 |
31 | class EMNISTDataModule(pl.LightningDataModule):
32 | def __init__(self, config):
33 | super().__init__()
34 | self.config = config
35 | self.dm_config = config["data_module"]
36 |
37 | def setup(self, stage=None):
38 | train_tfms = transforms.Compose(
39 | [
40 | transforms.Resize(32),
41 | transforms.RandomHorizontalFlip(p=0.5),
42 | transforms.ColorJitter(brightness=0.2, contrast=0.2),
43 | transforms.Grayscale(),
44 | transforms.ToTensor(),
45 | ]
46 | )
47 | val_tfms = transforms.Compose(
48 | [
49 | transforms.Resize(32),
50 | transforms.Grayscale(),
51 | transforms.ToTensor(),
52 | ]
53 | )
54 |
55 | self.train_dataset = EMNISTDataset(
56 | sampling_ratio=self.config["sampling_ratio"],
57 | bcs=self.config["bcs"],
58 | tfms=train_tfms,
59 | train=True,
60 | )
61 |
62 | self.val_dataset = EMNISTDataset(
63 | sampling_ratio=self.config["sampling_ratio"],
64 | bcs=self.config["bcs"],
65 | tfms=val_tfms,
66 | train=True,
67 | )
68 |
69 | self.test_dataset = EMNISTDataset(
70 | sampling_ratio=self.config["sampling_ratio"],
71 | bcs=self.config["bcs"],
72 | tfms=val_tfms,
73 | train=False,
74 | )
75 |
76 | dataset_size = len(self.train_dataset)
77 | indices = torch.randperm(dataset_size)
78 | split = int(self.dm_config["val_percent"] * dataset_size)
79 | self.train_idx, self.valid_idx = indices[split:], indices[:split]
80 |
81 | def train_dataloader(self):
82 | train_sampler = SubsetRandomSampler(self.train_idx)
83 | return DataLoader(
84 | self.train_dataset,
85 | batch_size=self.dm_config["batch_size"],
86 | sampler=train_sampler,
87 | num_workers=self.dm_config["num_workers"],
88 | )
89 |
90 | def val_dataloader(self):
91 | val_sampler = SubsetRandomSampler(self.valid_idx)
92 | return DataLoader(
93 | self.val_dataset,
94 | batch_size=self.dm_config["batch_size"],
95 | sampler=val_sampler,
96 | num_workers=self.dm_config["num_workers"],
97 | )
98 |
99 | def test_dataloader(self):
100 | return DataLoader(
101 | self.test_dataset,
102 | batch_size=self.dm_config["batch_size"],
103 | num_workers=self.dm_config["num_workers"],
104 | )
105 |
106 | def predict_dataloader(self):
107 | return DataLoader(
108 | self.test_dataset,
109 | batch_size=self.dm_config["batch_size"],
110 | num_workers=self.dm_config["num_workers"],
111 | )
112 |
--------------------------------------------------------------------------------
/src/benchmark/reconnet/inference_spi.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import time
4 | import argparse
5 | import warnings
6 | import numpy as np
7 | import cv2
8 | import scipy.ndimage
9 | import scipy.io
10 | import math
11 | import torch
12 | import pytorch_lightning as pl
13 | from .learner import ReconNetLearner
14 | from utils import voltage2pixel, load_config
15 |
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument(
19 | "-s",
20 | "--sampling_ratio",
21 | type=float,
22 | required=True,
23 | help="Sampling ratio in percentage",
24 | )
25 | args = parser.parse_args()
26 |
27 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
28 | warnings.simplefilter("ignore")
29 |
30 |
31 | def setup():
32 | inference_config = load_config("../config/inference_config.yaml")
33 | sr = args.sampling_ratio
34 |
35 | checkpoint_folder = f"../logs/ReconNet_STL10_{int(sr * 100):04d}/best"
36 |
37 | if not os.path.exists(checkpoint_folder):
38 | run_name = os.listdir(Path(checkpoint_folder).parent)[-1]
39 | checkpoint_path = (
40 | f"{Path(checkpoint_folder).parent}/{run_name}/checkpoints/last.ckpt"
41 | )
42 | message = (
43 | f"The checkpoint from the run '{run_name}' is selected by default. "
44 | + "If this is not intended, change the name of the preferred checkpoint folder to 'best'."
45 | )
46 | print(message)
47 | else:
48 | checkpoint_path = f"{checkpoint_folder}/checkpoints/last.ckpt"
49 |
50 | learner = ReconNetLearner.load_from_checkpoint(checkpoint_path=checkpoint_path)
51 |
52 | trainer = pl.Trainer(
53 | gpus=1 if inference_config["gpu"] else 0,
54 | logger=False,
55 | default_root_dir="../",
56 | )
57 | return learner, trainer
58 |
59 |
60 | class RealDataset:
61 | def __init__(self, sampling_ratio, inference_config):
62 | self.real_data = inference_config["real_data"]
63 | self.phi = np.load(inference_config["measurement_matrix"])
64 | self.c = int(sampling_ratio / 100 * 16)
65 |
66 | def __getitem__(self, idx):
67 | real_data = self.real_data[idx]
68 | path = os.path.join("../inference_input", real_data["filename"])
69 | y_input = scipy.io.loadmat(path)["y"]
70 |
71 | y_input = y_input[
72 | np.mod(np.arange(len(y_input)), len(y_input) // 64) < self.c
73 | ] # discard extra measurements
74 |
75 | y_input = torch.FloatTensor(y_input).permute(1, 0)
76 | y_input -= y_input.min()
77 | y_input /= real_data["max"]
78 |
79 |         # The measurements were acquired in channel-last order during sampling,
80 |         # so permute them into PyTorch's channel-first layout.
81 | y_input = y_input.view(-1, self.c)
82 | y_input = y_input.permute(1, 0).contiguous()
83 | y_input = y_input.view(
84 | -1, int(math.sqrt(y_input.shape[-1])), int(math.sqrt(y_input.shape[-1]))
85 | )
86 |
87 | y_input = voltage2pixel(
88 | y_input, self.phi[: self.c], real_data["min"], real_data["max"]
89 | )
90 | return y_input
91 |
92 | def __len__(self):
93 | return len(self.real_data)
94 |
95 |
96 | def deploy(learner):
97 | """Real experimental data"""
98 | inference_config = load_config("../config/inference_config.yaml")
99 | sr = args.sampling_ratio
100 | directory = f"../inference_images/ReconNet/SPI/{int(sr * 100):04d}"
101 | os.makedirs(directory, exist_ok=True)
102 | real_dataset = RealDataset(sr, inference_config)
103 | for x in real_dataset:
104 | prediction = learner(x.unsqueeze(0))
105 |         prediction = prediction.squeeze().cpu().detach().numpy()
106 | prediction = scipy.ndimage.zoom(prediction, 4, order=0, mode="nearest")
107 | cv2.imwrite(f"{directory}/{time.time()}.png", prediction * 255)
108 | print("Finished reconstructing SPI images.")
109 |
110 |
111 | if __name__ == "__main__":
112 | learner, trainer = setup()
113 | deploy(learner)
114 |
--------------------------------------------------------------------------------
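
The indexing trick in RealDataset.__getitem__ above keeps the first self.c entries of every contiguous group of len(y_input) // 64 measurements and discards the rest. A minimal numeric sketch with a dummy group size:

    import numpy as np

    g, c = 4, 2                    # g plays the role of len(y_input) // 64
    y = np.arange(3 * g)           # three groups, for brevity
    print(y[np.mod(np.arange(len(y)), g) < c])  # [0 1 4 5 8 9]: first c per group
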
/src/benchmark/scsnet/inference_spi.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import time
4 | import warnings
5 | import numpy as np
6 | import cv2
7 | import argparse
8 | import scipy.ndimage
9 | import scipy.io
10 | import math
11 | import torch
12 | import pytorch_lightning as pl
13 | from .learner import SCSNetLearner
14 | from utils import voltage2pixel, load_config
15 |
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument(
19 | "-s",
20 | "--sampling_ratio",
21 | type=float,
22 | required=True,
23 | help="Sampling ratio in percentage",
24 | )
25 | args = parser.parse_args()
26 |
27 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
28 | warnings.simplefilter("ignore")
29 |
30 |
31 | def setup():
32 | inference_config = load_config("../config/inference_config.yaml")
33 |     sr = args.sampling_ratio
34 |
35 | checkpoint_folder = f"../logs/SCSNet_STL10_{int(sr * 100):04d}/best"
36 |
37 | if not os.path.exists(checkpoint_folder):
38 | run_name = os.listdir(Path(checkpoint_folder).parent)[0]
39 | checkpoint_path = (
40 | f"{Path(checkpoint_folder).parent}/{run_name}/checkpoints/last.ckpt"
41 | )
42 |         print(
43 |             f"The checkpoint from the run '{run_name}' is selected by default. "
44 |             "If this is not intended, change the name of the preferred checkpoint folder to 'best'."
45 |         )
46 | else:
47 | checkpoint_path = f"{checkpoint_folder}/checkpoints/last.ckpt"
48 |
49 | learner = SCSNetLearner.load_from_checkpoint(checkpoint_path=checkpoint_path)
50 |
51 | trainer = pl.Trainer(
52 | gpus=1 if inference_config["gpu"] else 0,
53 | logger=False,
54 | default_root_dir="../",
55 | )
56 | return learner, trainer
57 |
58 |
59 | class RealDataset:
60 | def __init__(self, sampling_ratio, inference_config):
61 | self.real_data = inference_config["real_data"]
62 | self.phi = np.load(inference_config["measurement_matrix"])
63 | self.c = int(sampling_ratio / 100 * 16)
64 |
65 | def __getitem__(self, idx):
66 | real_data = self.real_data[idx]
67 | path = os.path.join("../inference_input", real_data["filename"])
68 | y_input = scipy.io.loadmat(path)["y"]
69 |
70 | y_input = y_input[
71 | np.mod(np.arange(len(y_input)), len(y_input) // 64) < self.c
72 | ] # discard extra measurements
73 |
74 | y_input = torch.FloatTensor(y_input).permute(1, 0)
75 | y_input -= y_input.min()
76 | y_input /= real_data["max"]
77 |
78 |         # The measurements were acquired in channel-last order during sampling,
79 |         # so permute them into PyTorch's channel-first layout.
80 | y_input = y_input.view(-1, self.c)
81 | y_input = y_input.permute(1, 0).contiguous()
82 | y_input = y_input.view(
83 | -1, int(math.sqrt(y_input.shape[-1])), int(math.sqrt(y_input.shape[-1]))
84 | )
85 |
86 | y_input = voltage2pixel(
87 | y_input, self.phi[: self.c], real_data["min"], real_data["max"]
88 | )
89 | return y_input
90 |
91 | def __len__(self):
92 | return len(self.real_data)
93 |
94 |
95 | def predict_one():
96 | # TODO: for standard dataset test set.
97 | return
98 |
99 |
100 | def deploy(learner):
101 | """Real experimental data"""
102 | inference_config = load_config("../config/inference_config.yaml")
103 | sr = args.sampling_ratio
104 | directory = f"../inference_images/SCSNet/SPI/{int(sr * 100):04d}"
105 | os.makedirs(directory, exist_ok=True)
106 | real_dataset = RealDataset(sr, inference_config)
107 | for x in real_dataset:
108 | prediction = learner(x.unsqueeze(0))
109 |         prediction = prediction.squeeze().cpu().detach().numpy()
110 | prediction = scipy.ndimage.zoom(prediction, 4, order=0, mode="nearest")
111 | cv2.imwrite(f"{directory}/{time.time()}.png", prediction * 255)
112 | print("Finished reconstructing SPI images.")
113 |
114 |
115 | if __name__ == "__main__":
116 | learner, trainer = setup()
117 | deploy(learner)
118 |
--------------------------------------------------------------------------------
/src/inference_spi.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import time
4 | import argparse
5 | import warnings
6 | import numpy as np
7 | import cv2
8 | import scipy.ndimage
9 | import scipy.io
10 | import math
11 | import torch
12 | import pytorch_lightning as pl
13 | from engine.learner import BCSUNetLearner
14 | from utils import voltage2pixel, load_config
15 |
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument(
19 | "-d", "--dataset", type=str, required=True, help="'EMNIST', 'SVHN', or 'STL10'"
20 | )
21 | parser.add_argument(
22 | "-s",
23 | "--sampling_ratio",
24 | type=float,
25 | required=True,
26 | help="Sampling ratio in percentage",
27 | )
28 | args = parser.parse_args()
29 |
30 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
31 | warnings.simplefilter("ignore")
32 |
33 |
34 | def setup():
35 | inference_config = load_config("../config/inference_config.yaml")
36 | ds = args.dataset
37 | sr = args.sampling_ratio
38 |
39 | checkpoint_folder = f"../logs/BCSUNet_{ds}_{int(sr * 100):04d}/best"
40 |
41 | if not os.path.exists(checkpoint_folder):
42 | run_name = os.listdir(Path(checkpoint_folder).parent)[0]
43 | checkpoint_path = (
44 | f"{Path(checkpoint_folder).parent}/{run_name}/checkpoints/best.ckpt"
45 | )
46 | message = (
47 |             f"The checkpoint from the run '{run_name}' is selected by default. "
48 | + "If this is not intended, change the name of the preferred checkpoint folder to 'best'."
49 | )
50 | print(message)
51 | else:
52 | checkpoint_path = f"{checkpoint_folder}/checkpoints/best.ckpt"
53 |
54 | learner = BCSUNetLearner.load_from_checkpoint(checkpoint_path=checkpoint_path)
55 |
56 | trainer = pl.Trainer(
57 | gpus=1 if inference_config["gpu"] else 0,
58 | logger=False,
59 | default_root_dir="../",
60 | )
61 | return learner, trainer
62 |
63 |
64 | class RealDataset:
65 | def __init__(self, sampling_ratio, inference_config):
66 | self.real_data = inference_config["real_data"]
67 | self.phi = np.load(inference_config["measurement_matrix"])
68 | self.c = int(sampling_ratio / 100 * 16)
69 |
70 | def __getitem__(self, idx):
71 | real_data = self.real_data[idx]
72 | path = os.path.join("../inference_input", real_data["filename"])
73 | y_input = scipy.io.loadmat(path)["y"]
74 |
75 | y_input = y_input[
76 | np.mod(np.arange(len(y_input)), len(y_input) // 64) < self.c
77 | ] # discard extra measurements
78 |
79 | y_input = torch.FloatTensor(y_input).permute(1, 0)
80 | y_input -= y_input.min()
81 | y_input /= real_data["max"]
82 |
83 |         # The measurements were acquired in channel-last order during sampling,
84 |         # so permute them into PyTorch's channel-first layout.
85 | y_input = y_input.view(-1, self.c)
86 | y_input = y_input.permute(1, 0).contiguous()
87 | y_input = y_input.view(
88 | -1, int(math.sqrt(y_input.shape[-1])), int(math.sqrt(y_input.shape[-1]))
89 | )
90 |
91 | y_input = voltage2pixel(
92 | y_input, self.phi[: self.c], real_data["min"], real_data["max"]
93 | )
94 | return y_input
95 |
96 | def __len__(self):
97 | return len(self.real_data)
98 |
99 |
100 | def deploy(learner):
101 | """Real experimental data"""
102 | inference_config = load_config("../config/inference_config.yaml")
103 | sr = args.sampling_ratio
104 | real_dataset = RealDataset(sr, inference_config)
105 | save_dir = f"../inference_images/BCSUNet/SPI/{int(sr * 100):04d}_{args.dataset}"
106 | os.makedirs(save_dir, exist_ok=True)
107 |
108 | for x in real_dataset:
109 | prediction = learner(x.unsqueeze(0))
110 |         prediction = prediction.squeeze().cpu().detach().numpy()
111 | prediction = scipy.ndimage.zoom(prediction, 8, order=0, mode="nearest")
112 | cv2.imwrite(f"{save_dir}/{time.time()}.png", prediction * 255)
113 |
114 | print("Finished reconstructing SPI images.")
115 |
116 |
117 | if __name__ == "__main__":
118 | learner, trainer = setup()
119 | deploy(learner)
120 |
--------------------------------------------------------------------------------
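
A shape-only sketch of the channel-last to channel-first reshuffle in RealDataset.__getitem__, under the assumption of c = 4 kept measurements per block and an 8x8 grid of 64 blocks (the values are dummies):

    import math
    import torch

    c = 4
    y = torch.arange(64 * c, dtype=torch.float32)  # flat: each block's codes contiguous
    y = y.view(-1, c).permute(1, 0).contiguous()   # (c, 64): one channel per code
    side = int(math.sqrt(y.shape[-1]))
    print(y.view(-1, side, side).shape)            # torch.Size([4, 8, 8])
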
/src/model/spectral_norm.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of spectral normalization for GANs.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class SpectralNorm:
10 | r"""
11 | Spectral Normalization for GANs (Miyato 2018).
12 |
13 | Inheritable class for performing spectral normalization of weights,
14 | as approximated using power iteration.
15 |
16 | Details: See Algorithm 1 of Appendix A (Miyato 2018).
17 |
18 | Attributes:
19 | n_dim (int): Number of dimensions.
20 | num_iters (int): Number of iterations for power iter.
21 | eps (float): Epsilon for zero division tolerance when normalizing.
22 | """
23 |
24 | def __init__(self, n_dim, num_iters=1, eps=1e-12):
25 | self.num_iters = num_iters
26 | self.eps = eps
27 |
28 | # Register a singular vector for each sigma
29 | self.register_buffer("sn_u", torch.randn(1, n_dim))
30 | self.register_buffer("sn_sigma", torch.ones(1))
31 |
32 | @property
33 | def u(self):
34 | return getattr(self, "sn_u")
35 |
36 | @property
37 | def sigma(self):
38 | return getattr(self, "sn_sigma")
39 |
40 | def _power_iteration(self, W, u, num_iters, eps=1e-12):
41 | with torch.no_grad():
42 | for _ in range(num_iters):
43 | v = F.normalize(torch.matmul(u, W), eps=eps)
44 | u = F.normalize(torch.matmul(v, W.t()), eps=eps)
45 |
46 | # Note: must have gradients, otherwise weights do not get updated!
47 | sigma = torch.mm(u, torch.mm(W, v.t()))
48 |
49 | return sigma, u, v
50 |
51 | def sn_weights(self):
52 | r"""
53 | Spectrally normalize current weights of the layer.
54 | """
55 | W = self.weight.view(self.weight.shape[0], -1)
56 |
57 | # Power iteration
58 | sigma, u, v = self._power_iteration(
59 | W=W, u=self.u, num_iters=self.num_iters, eps=self.eps
60 | )
61 |
62 | # Update only during training
63 | if self.training:
64 | with torch.no_grad():
65 | self.sigma[:] = sigma
66 | self.u[:] = u
67 |
68 | return self.weight / sigma
69 |
70 |
71 | class SNConv2d(nn.Conv2d, SpectralNorm):
72 | r"""
73 | Spectrally normalized layer for Conv2d.
74 |
75 | Attributes:
76 | in_channels (int): Input channel dimension.
77 | out_channels (int): Output channel dimensions.
78 | """
79 |
80 | def __init__(self, in_channels, out_channels, *args, **kwargs):
81 | nn.Conv2d.__init__(self, in_channels, out_channels, *args, **kwargs)
82 |
83 | SpectralNorm.__init__(
84 | self, n_dim=out_channels, num_iters=kwargs.get("num_iters", 1)
85 | )
86 |
87 | def forward(self, x):
88 | return F.conv2d(
89 | input=x,
90 | weight=self.sn_weights(),
91 | bias=self.bias,
92 | stride=self.stride,
93 | padding=self.padding,
94 | dilation=self.dilation,
95 | groups=self.groups,
96 | )
97 |
98 |
99 | class SNLinear(nn.Linear, SpectralNorm):
100 | r"""
101 | Spectrally normalized layer for Linear.
102 |
103 | Attributes:
104 | in_features (int): Input feature dimensions.
105 | out_features (int): Output feature dimensions.
106 | """
107 |
108 | def __init__(self, in_features, out_features, *args, **kwargs):
109 | nn.Linear.__init__(self, in_features, out_features, *args, **kwargs)
110 |
111 | SpectralNorm.__init__(
112 | self, n_dim=out_features, num_iters=kwargs.get("num_iters", 1)
113 | )
114 |
115 | def forward(self, x):
116 | return F.linear(input=x, weight=self.sn_weights(), bias=self.bias)
117 |
118 |
119 | class SNEmbedding(nn.Embedding, SpectralNorm):
120 | r"""
121 | Spectrally normalized layer for Embedding.
122 |
123 | Attributes:
124 | num_embeddings (int): Number of embeddings.
125 | embedding_dim (int): Dimensions of each embedding vector
126 | """
127 |
128 | def __init__(self, num_embeddings, embedding_dim, *args, **kwargs):
129 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, *args, **kwargs)
130 |
131 | SpectralNorm.__init__(self, n_dim=num_embeddings)
132 |
133 | def forward(self, x):
134 | return F.embedding(input=x, weight=self.sn_weights())
135 |
--------------------------------------------------------------------------------
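
As a sanity check on the power iteration in SpectralNorm above, a standalone sketch that runs more iterations than the layer's num_iters=1 default and compares the estimate with the top singular value from torch.svd:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    W = torch.randn(16, 32)
    u = torch.randn(1, 16)
    for _ in range(50):  # extra iterations for a tight estimate
        v = F.normalize(torch.matmul(u, W), eps=1e-12)
        u = F.normalize(torch.matmul(v, W.t()), eps=1e-12)
    sigma = torch.mm(u, torch.mm(W, v.t()))
    print(sigma.item(), torch.svd(W).S[0].item())  # both ~ the largest singular value
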
/src/model/upsamplenet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.nn import functional as F
3 | from model.layers import SNConv2d
4 |
5 |
6 | class ReshapeNet(nn.Module):
7 | """The "initial reconstruction network" of SCSNet"""
8 |
9 | def __init__(self, in_channels, block_size=4):
10 | super().__init__()
11 | self.block_size = block_size
12 | self.conv = nn.Conv2d(in_channels, block_size ** 2, kernel_size=1)
13 |
14 | def forward(self, x):
15 | x = self.conv(x)
16 | out = self._permute(x)
17 | return out
18 |
19 | def _permute(self, x):
20 | B, C, H, W = x.shape
21 | x = x.permute(0, 2, 3, 1)
22 | x = x.view(B, H, W, self.block_size, self.block_size)
23 | x = x.permute(0, 1, 3, 2, 4).contiguous()
24 | out = x.view(-1, 1, H * self.block_size, W * self.block_size)
25 | return out
26 |
27 |
28 | class UpsampleNet(nn.Module):
29 | def __init__(self, sampling_ratio, upsamplenet_config):
30 | super().__init__()
31 | kernel_size = 4
32 | first_out_channels = int(sampling_ratio * kernel_size ** 2)
33 | config = upsamplenet_config
34 |
35 | self.up1 = UpResBlock(
36 | in_channels=first_out_channels,
37 | out_channels=config["out_channels_1"],
38 | middle_channels=None,
39 | upsample=True,
40 | use_transpose_conv=config["use_transpose_conv"],
41 | spectral_norm=config["spectral_norm"],
42 | )
43 |
44 | self.up2 = UpResBlock(
45 | in_channels=config["out_channels_1"],
46 | out_channels=config["out_channels_2"],
47 | middle_channels=None,
48 | upsample=True,
49 | use_transpose_conv=config["use_transpose_conv"],
50 | spectral_norm=config["spectral_norm"],
51 | )
52 |
53 | def forward(self, x):
54 | x = self.up1(x)
55 | out = self.up2(x) # passed to UNet
56 | return out
57 |
58 |
59 | class UpResBlock(nn.Module):
60 | def __init__(
61 | self,
62 | in_channels,
63 | out_channels,
64 | middle_channels=None,
65 | upsample=True,
66 | use_transpose_conv=False,
67 | norm_type="instance",
68 | spectral_norm=True,
69 | init_type="xavier",
70 | ):
71 |
72 | super().__init__()
73 | self.upsample = upsample
74 | self.use_transpose_conv = use_transpose_conv
75 |
76 | if middle_channels is None:
77 | middle_channels = out_channels
78 |
79 | if use_transpose_conv:
80 | assert upsample is True
81 | self.conv1 = nn.ConvTranspose2d(
82 | in_channels,
83 | middle_channels,
84 | kernel_size=2,
85 | stride=2,
86 |                 padding=0,  # kernel 2, stride 2, padding 0: doubles the size exactly
87 | bias=False,
88 | )
89 | self.conv2 = nn.ConvTranspose2d(
90 | middle_channels,
91 | out_channels,
92 |                 kernel_size=3,
93 |                 stride=1,  # size-preserving, so the block nets a single 2x upsample
94 |                 padding=1,
95 | bias=False,
96 | )
97 |
98 | else: # if transpose conv is not used.
99 | # The `_residual_block` method will decide whether or not it upsamples depending on `upsample == True/False`
100 | conv = SNConv2d if spectral_norm else nn.Conv2d
101 | self.conv1 = conv(
102 | in_channels,
103 | middle_channels,
104 | kernel_size=3,
105 | stride=1,
106 | padding=1,
107 | bias=False,
108 | )
109 | self.conv2 = conv(
110 | middle_channels,
111 | out_channels,
112 | kernel_size=3,
113 | stride=1,
114 | padding=1,
115 | bias=False,
116 | )
117 |
118 | self.bn1 = nn.BatchNorm2d(middle_channels)
119 | self.bn2 = nn.BatchNorm2d(out_channels)
120 |
121 | self.relu = nn.ReLU(inplace=True)
122 |
123 | def _upsample(self, x, conv_layer):
124 | if self.use_transpose_conv:
125 | return conv_layer(x)
126 | else:
127 | return conv_layer(
128 | F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
129 | )
130 |
131 | def _residual_block(self, x):
132 | x = self._upsample(x, self.conv1) if self.upsample else self.conv1(x)
133 | x = self.bn1(x)
134 | x = self.relu(x)
135 | x = self.conv2(x)
136 | x = self.bn2(x)
137 | out = self.relu(x)
138 | return out
139 |
140 | def _shortcut(self, x):
141 | return self._upsample(x, self.conv1) if self.upsample else self.conv1(x)
142 |
143 | def forward(self, x):
144 | return self._residual_block(x) + self._shortcut(x)
145 |
--------------------------------------------------------------------------------
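
ReshapeNet._permute rearranges each position's block_size**2 channels into a block_size x block_size spatial tile, which for a single output channel coincides with F.pixel_shuffle. A self-contained check (re-declaring the permutation so the snippet runs on its own):

    import torch
    import torch.nn.functional as F

    def permute_blocks(x, r):
        # Same steps as ReshapeNet._permute: channels -> r x r spatial tiles.
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).reshape(B, H, W, r, r)
        x = x.permute(0, 1, 3, 2, 4).contiguous()
        return x.view(B, 1, H * r, W * r)

    x = torch.randn(2, 16, 8, 8)  # block_size = 4
    assert torch.equal(permute_blocks(x, 4), F.pixel_shuffle(x, 4))
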
/src/data/stl10.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from torch.utils.data.sampler import SubsetRandomSampler
6 | import torchvision
7 | from torchvision import transforms
8 | import pytorch_lightning as pl
9 | from .base_dataset import BaseDataset
10 | from utils import create_patches
11 |
12 |
13 | class STL10Dataset(BaseDataset):
14 | def __init__(self, sampling_ratio: float, bcs: bool, tfms=None, train=True):
15 | super().__init__(sampling_ratio=sampling_ratio, bcs=bcs)
16 | self.data = torchvision.datasets.STL10(
17 | "../input/STL10",
18 | split="unlabeled" if train else "test",
19 | transform=tfms,
20 | download=True,
21 | )
22 |
23 | def __getitem__(self, idx):
24 | image, _ = self.data[idx]
25 | image_ = image.unsqueeze(dim=1)
26 | y = self.cs_operator(image_)
27 | return y.squeeze(dim=0), image
28 |
29 | def __len__(self):
30 | return len(self.data)
31 |
32 |
33 | class STL10ReconnetTestDataset(BaseDataset):
34 |     """
35 |     The test dataset for ReconNet is special: ReconNet only handles small images,
36 |     so we reconstruct 32x32 patches and then stitch them back together.
37 |     To that end, the 96x96 STL10 test images are cropped into 32x32 patches.
38 |     """
39 |
40 | def __init__(self, sampling_ratio: float, bcs: bool):
41 | super().__init__(sampling_ratio=sampling_ratio, bcs=bcs)
42 | self.root_dir = "../input/STL10"
43 | self.data_dir = os.path.join(self.root_dir, "test_images_32x32")
44 | if not os.path.exists(self.data_dir):
45 | create_patches(self.root_dir, size=32)
46 | self.filenames = os.listdir(self.data_dir)
47 |
48 | def __getitem__(self, idx):
49 | image = cv2.imread(
50 | os.path.join(self.data_dir, self.filenames[idx]),
51 | cv2.IMREAD_GRAYSCALE,
52 | )
53 | image = torch.tensor(image) / 255.0
54 | image = image.unsqueeze(dim=0)
55 | image_ = image.unsqueeze(dim=1)
56 | y = self.cs_operator(image_)
57 | return y.squeeze(dim=0), image
58 |
59 | def __len__(self):
60 | return len(self.filenames)
61 |
62 |
63 | class STL10DataModule(pl.LightningDataModule):
64 | def __init__(self, config, reconnet=False):
65 | super().__init__()
66 | self.config = config
67 | self.dm_config = config["data_module"]
68 | self.reconnet = reconnet # whether or not ReconNet is the architecture.
69 |
70 | def setup(self, stage=None):
71 | train_tfms_list = [transforms.RandomCrop(32)] if self.reconnet else []
72 | train_tfms_list += [
73 | transforms.RandomHorizontalFlip(p=0.5),
74 | transforms.ColorJitter(brightness=0.2, contrast=0.2),
75 | transforms.Grayscale(),
76 | transforms.ToTensor(),
77 | ]
78 | train_tfms = transforms.Compose(train_tfms_list)
79 |
80 | val_tfms_list = [transforms.CenterCrop(32)] if self.reconnet else []
81 | val_tfms_list += [
82 | transforms.Grayscale(),
83 | transforms.ToTensor(),
84 | ]
85 | val_tfms = transforms.Compose(val_tfms_list)
86 |
87 | self.train_dataset = STL10Dataset(
88 | sampling_ratio=self.config["sampling_ratio"],
89 | bcs=self.config["bcs"],
90 | tfms=train_tfms,
91 | train=True,
92 | )
93 |
94 | self.val_dataset = STL10Dataset(
95 | sampling_ratio=self.config["sampling_ratio"],
96 | bcs=self.config["bcs"],
97 | tfms=val_tfms,
98 | train=True,
99 | )
100 |
101 | if self.reconnet:
102 | self.test_dataset = STL10ReconnetTestDataset(
103 | sampling_ratio=self.config["sampling_ratio"],
104 | bcs=self.config["bcs"],
105 | )
106 | else:
107 | self.test_dataset = STL10Dataset(
108 | sampling_ratio=self.config["sampling_ratio"],
109 | bcs=self.config["bcs"],
110 | tfms=val_tfms,
111 | train=False,
112 | )
113 |
114 | dataset_size = len(self.train_dataset)
115 | indices = torch.randperm(dataset_size)
116 | split = int(self.dm_config["val_percent"] * dataset_size)
117 | self.train_idx, self.valid_idx = indices[split:], indices[:split]
118 |
119 | def train_dataloader(self):
120 | train_sampler = SubsetRandomSampler(self.train_idx)
121 | return DataLoader(
122 | self.train_dataset,
123 | batch_size=self.dm_config["batch_size"],
124 | sampler=train_sampler,
125 | num_workers=self.dm_config["num_workers"],
126 | )
127 |
128 | def val_dataloader(self):
129 | val_sampler = SubsetRandomSampler(self.valid_idx)
130 | return DataLoader(
131 | self.val_dataset,
132 | batch_size=self.dm_config["batch_size"],
133 | sampler=val_sampler,
134 | num_workers=self.dm_config["num_workers"],
135 | )
136 |
137 | def test_dataloader(self):
138 | return DataLoader(
139 | self.test_dataset,
140 | batch_size=self.dm_config["batch_size"],
141 | num_workers=self.dm_config["num_workers"],
142 | )
143 |
144 | def predict_dataloader(self):
145 | if self.reconnet:
146 |             batch_size = 9 * 12  # 9 patches (3x3 grid of 32x32 crops) per 96x96 image
147 | else:
148 | batch_size = self.dm_config["batch_size"]
149 |
150 | return DataLoader(
151 | self.test_dataset,
152 | batch_size=batch_size,
153 | num_workers=self.dm_config["num_workers"],
154 | )
155 |
--------------------------------------------------------------------------------
/src/model/layers.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for building specific layers needed by GAN architecture:
3 | https://github.com/kwotsin/mimicry/tree/master/torch_mimicry/modules/layers.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from model import spectral_norm
10 |
11 |
12 | class SelfAttention(nn.Module):
13 | """
14 | Self-attention layer based on version used in BigGAN code:
15 | https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py
16 | """
17 |
18 | def __init__(self, num_feat, spectral_norm=True):
19 | super().__init__()
20 | self.num_feat = num_feat
21 | self.spectral_norm = spectral_norm
22 |
23 | if self.spectral_norm:
24 | self.theta = SNConv2d(
25 | self.num_feat, self.num_feat >> 3, 1, 1, padding=0, bias=False
26 | )
27 | self.phi = SNConv2d(
28 | self.num_feat, self.num_feat >> 3, 1, 1, padding=0, bias=False
29 | )
30 | self.g = SNConv2d(
31 | self.num_feat, self.num_feat >> 1, 1, 1, padding=0, bias=False
32 | )
33 | self.o = SNConv2d(
34 | self.num_feat >> 1, self.num_feat, 1, 1, padding=0, bias=False
35 | )
36 |
37 | else:
38 | self.theta = nn.Conv2d(
39 | self.num_feat, self.num_feat >> 3, 1, 1, padding=0, bias=False
40 | )
41 | self.phi = nn.Conv2d(
42 | self.num_feat, self.num_feat >> 3, 1, 1, padding=0, bias=False
43 | )
44 | self.g = nn.Conv2d(
45 | self.num_feat, self.num_feat >> 1, 1, 1, padding=0, bias=False
46 | )
47 | self.o = nn.Conv2d(
48 | self.num_feat >> 1, self.num_feat, 1, 1, padding=0, bias=False
49 | )
50 |
51 | self.gamma = nn.Parameter(torch.tensor(0.0), requires_grad=True)
52 |
53 | def forward(self, x):
54 | """
55 | Feedforward function. Implementation differs from actual SAGAN paper,
56 | see note from BigGAN:
57 | https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py#L142
58 |
59 | See official TF Implementation:
60 | https://github.com/brain-research/self-attention-gan/blob/master/non_local.py
61 |
62 | Args:
63 | x (Tensor): Input feature map.
64 |
65 | Returns:
66 | Tensor: Feature map weighed with attention map.
67 | """
68 | N, C, H, W = x.shape
69 | location_num = H * W
70 | downsampled_num = location_num >> 2
71 |
72 | # Theta path
73 | theta = self.theta(x)
74 | theta = theta.view(N, C >> 3, location_num) # (N, C>>3, H*W)
75 |
76 | # Phi path
77 | phi = self.phi(x)
78 | phi = F.max_pool2d(phi, [2, 2], stride=2)
79 | phi = phi.view(N, C >> 3, downsampled_num) # (N, C>>3, H*W>>2)
80 |
81 | # Attention map
82 | attn = torch.bmm(theta.transpose(1, 2), phi)
83 | attn = F.softmax(attn, -1) # (N, H*W, H*W>>2)
84 | # print(torch.sum(attn, axis=2)) # (N, H*W)
85 |
86 | # Conv value
87 | g = self.g(x)
88 | g = F.max_pool2d(g, [2, 2], stride=2)
89 | g = g.view(N, C >> 1, downsampled_num) # (N, C>>1, H*W>>2)
90 |
91 | # Apply attention
92 | attn_g = torch.bmm(g, attn.transpose(1, 2)) # (N, C>>1, H*W)
93 | attn_g = attn_g.view(N, C >> 1, H, W) # (N, C>>1, H, W)
94 |
95 | # Project back feature size
96 | attn_g = self.o(attn_g)
97 |
98 | # Weigh attention map
99 | output = x + self.gamma * attn_g
100 |
101 | return output
102 |
103 |
104 | def SNConv2d(*args, default=True, **kwargs):
105 | r"""
106 | Wrapper for applying spectral norm on conv2d layer.
107 | """
108 | if default:
109 | return nn.utils.spectral_norm(nn.Conv2d(*args, **kwargs))
110 |
111 | else:
112 | return spectral_norm.SNConv2d(*args, **kwargs)
113 |
114 |
115 | def SNLinear(*args, default=True, **kwargs):
116 | r"""
117 | Wrapper for applying spectral norm on linear layer.
118 | """
119 | if default:
120 | return nn.utils.spectral_norm(nn.Linear(*args, **kwargs))
121 |
122 | else:
123 | return spectral_norm.SNLinear(*args, **kwargs)
124 |
125 |
126 | def SNEmbedding(*args, default=True, **kwargs):
127 | r"""
128 | Wrapper for applying spectral norm on embedding layer.
129 | """
130 | if default:
131 | return nn.utils.spectral_norm(nn.Embedding(*args, **kwargs))
132 |
133 | else:
134 | return spectral_norm.SNEmbedding(*args, **kwargs)
135 |
136 |
137 | class ConditionalBatchNorm2d(nn.Module):
138 | r"""
139 | Conditional Batch Norm as implemented in
140 | https://github.com/pytorch/pytorch/issues/8985
141 |
142 | Attributes:
143 | num_features (int): Size of feature map for batch norm.
144 | num_classes (int): Determines size of embedding layer to condition BN.
145 | """
146 |
147 | def __init__(self, num_features, num_classes):
148 | super().__init__()
149 | self.num_features = num_features
150 | self.bn = nn.BatchNorm2d(num_features, affine=False)
151 | self.embed = nn.Embedding(num_classes, num_features * 2)
152 | self.embed.weight.data[:, :num_features].normal_(
153 | 1, 0.02
154 | ) # Initialise scale at N(1, 0.02)
155 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
156 |
157 | def forward(self, x, y):
158 | r"""
159 | Feedforwards for conditional batch norm.
160 |
161 | Args:
162 | x (Tensor): Input feature map.
163 | y (Tensor): Input class labels for embedding.
164 |
165 | Returns:
166 | Tensor: Output feature map.
167 | """
168 | out = self.bn(x)
169 | gamma, beta = self.embed(y).chunk(
170 | 2, 1
171 | ) # divide into 2 chunks, split from dim 1.
172 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(
173 | -1, self.num_features, 1, 1
174 | )
175 |
176 | return out
177 |
--------------------------------------------------------------------------------
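
The bookkeeping in SelfAttention.forward is easiest to follow with concrete shapes. A standalone sketch with random tensors standing in for the theta/phi/g conv outputs:

    import torch
    import torch.nn.functional as F

    N, C, H, W = 2, 64, 16, 16
    theta = torch.randn(N, C >> 3, H * W)        # theta path, flattened
    phi = torch.randn(N, C >> 3, (H * W) >> 2)   # phi path after 2x2 max-pool
    g = torch.randn(N, C >> 1, (H * W) >> 2)     # value path after 2x2 max-pool
    attn = F.softmax(torch.bmm(theta.transpose(1, 2), phi), dim=-1)  # (N, H*W, H*W>>2)
    attn_g = torch.bmm(g, attn.transpose(1, 2)).view(N, C >> 1, H, W)
    print(attn_g.shape)                          # torch.Size([2, 32, 16, 16])
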
/src/model/unet.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import torch
3 | from torch import nn
4 |
5 |
6 | class UNet(nn.Module):
7 | def __init__(
8 | self,
9 | config,
10 | input_nc=8,
11 | output_nc=1,
12 | num_downs=5,
13 | ):
14 |         """
15 |         Construct a U-Net generator.
16 |         Parameters:
17 |             input_nc (int)  -- the number of channels in input images
18 |             output_nc (int) -- the number of channels in output images
19 |             num_downs (int) -- the number of downsamplings in the U-Net; if
20 |                                num_downs == 7, a 128x128 image becomes 1x1 at the bottleneck
21 |             channels (int)  -- the number of filters in the last conv layer
22 |         We construct the U-Net from the innermost layer to the outermost layer.
23 |         It is a recursive process.
24 |         """
25 | super().__init__()
26 | channels = config["channels"]
27 | unet_block = UnetSkipConnectionBlock(
28 | channels * 8,
29 | channels * 8,
30 | input_nc=None,
31 | submodule=None,
32 | norm_layer=nn.BatchNorm2d,
33 | innermost=True,
34 | ) # add the innermost layer first
35 |
36 |         # Add intermediate layers with channels * 8 filters
37 | for i in range(num_downs - 5):
38 | unet_block = UnetSkipConnectionBlock(
39 | channels * 8,
40 | channels * 8,
41 | input_nc=None,
42 | submodule=unet_block,
43 | norm_layer=nn.BatchNorm2d,
44 | use_dropout=config["use_dropout"],
45 | )
46 |
47 | # Gradually reduce the number of filters from `channels*8` to `channels`
48 | unet_block = UnetSkipConnectionBlock(
49 | channels * 4,
50 | channels * 8,
51 | input_nc=None,
52 | submodule=unet_block,
53 | norm_layer=nn.BatchNorm2d,
54 | )
55 | unet_block = UnetSkipConnectionBlock(
56 | channels * 2,
57 | channels * 4,
58 | input_nc=None,
59 | submodule=unet_block,
60 | norm_layer=nn.BatchNorm2d,
61 | )
62 | unet_block = UnetSkipConnectionBlock(
63 | channels,
64 | channels * 2,
65 | input_nc=None,
66 | submodule=unet_block,
67 | norm_layer=nn.BatchNorm2d,
68 | )
69 |
70 | self.model = UnetSkipConnectionBlock(
71 | output_nc,
72 | channels,
73 | input_nc=input_nc,
74 | submodule=unet_block,
75 | outermost=True,
76 | norm_layer=nn.BatchNorm2d,
77 | ) # add the outermost layer
78 |
79 | def forward(self, input):
80 | return self.model(input)
81 |
82 |
83 | class UnetSkipConnectionBlock(nn.Module):
84 | """
85 | Defines the Unet submodule with skip connection.
86 | X -------------------identity----------------------
87 | |-- downsampling -- |submodule| -- upsampling --|
88 |
89 | """
90 |
91 | def __init__(
92 | self,
93 | outer_nc,
94 | inner_nc,
95 | input_nc=None,
96 | submodule=None,
97 | outermost=False,
98 | innermost=False,
99 | norm_layer=nn.BatchNorm2d,
100 | use_dropout=False,
101 | ):
102 | """
103 | Construct a Unet submodule with skip connections.
104 |
105 | Parameters:
106 | outer_nc (int) -- the number of filters in the outer conv layer
107 | inner_nc (int) -- the number of filters in the inner conv layer
108 | input_nc (int) -- the number of channels in input images/features
109 | submodule (UnetSkipConnectionBlock) -- previously defined submodules
110 | outermost (bool) -- if this module is the outermost module
111 | innermost (bool) -- if this module is the innermost module
112 | norm_layer -- normalization layer
113 | use_dropout (bool) -- if use dropout layers.
114 | """
115 |
116 | # TODO: add spectral norm
117 |
118 | super().__init__()
119 | self.outermost = outermost
120 | use_bias = (
121 | (norm_layer.func == nn.InstanceNorm2d)
122 | if type(norm_layer) == partial
123 | else (norm_layer == nn.InstanceNorm2d)
124 | ) # use_bias is False if batch norm is used
125 | if input_nc is None:
126 | input_nc = outer_nc
127 | downconv = nn.Conv2d(
128 | input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
129 | )
130 | downrelu = nn.LeakyReLU(0.2, inplace=True)
131 | downnorm = norm_layer(inner_nc)
132 | uprelu = nn.ReLU(inplace=True)
133 | upnorm = norm_layer(outer_nc)
134 |
135 | if outermost:
136 | upconv = nn.ConvTranspose2d(
137 | inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=True
138 | )
139 | down = [downconv]
140 | up = [uprelu, upconv, nn.Sigmoid()]
141 | model = down + [submodule] + up
142 |
143 | elif innermost:
144 | upconv = nn.ConvTranspose2d(
145 | inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
146 | )
147 | down = [downrelu, downconv]
148 | up = [uprelu, upconv, upnorm]
149 | model = down + up
150 |
151 | else: # middle layers
152 | upconv = nn.ConvTranspose2d(
153 | inner_nc * 2,
154 | outer_nc,
155 | kernel_size=4,
156 | stride=2,
157 | padding=1,
158 | bias=use_bias,
159 | )
160 | down = [downrelu, downconv, downnorm]
161 | up = [uprelu, upconv, upnorm]
162 | model = down + [submodule] + up
163 | if use_dropout:
164 | model = model + [nn.Dropout(0.5)]
165 |
166 | self.model = nn.Sequential(*model)
167 |
168 | def forward(self, x):
169 | if self.outermost:
170 | return self.model(x)
171 | else: # add skip connections
172 | return torch.cat([x, self.model(x)], dim=1)
173 |
--------------------------------------------------------------------------------
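
A minimal usage sketch for the U-Net above, assuming the file is importable as model.unet (matching the repo's import style) and that config carries the channels and use_dropout keys read in __init__; the values are placeholders. With num_downs=5, a 32x32 input reaches 1x1 at the bottleneck and is reconstructed back to 32x32:

    import torch
    from model.unet import UNet

    config = {"channels": 64, "use_dropout": False}  # placeholder values
    net = UNet(config, input_nc=8, output_nc=1, num_downs=5).eval()
    x = torch.randn(2, 8, 32, 32)  # five downsamplings: 32 -> 16 -> 8 -> 4 -> 2 -> 1
    print(net(x).shape)            # torch.Size([2, 1, 32, 32])
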