├── src ├── utils │ ├── __init__.py │ └── utils.py ├── scripts │ ├── bcsunet_all.sh │ ├── scsnet_all.sh │ ├── reconnet_all.sh │ └── train_all.sh ├── model │ ├── bcsunet.py │ ├── utils.py │ ├── spectral_norm.py │ ├── upsamplenet.py │ ├── layers.py │ └── unet.py ├── data │ ├── base_dataset.py │ ├── svhn.py │ ├── emnist.py │ └── stl10.py ├── engine │ ├── dispatcher.py │ └── learner.py ├── evaluate.py ├── benchmark │ ├── reconnet │ │ ├── net.py │ │ ├── train.py │ │ ├── learner.py │ │ ├── inference_simulated.py │ │ └── inference_spi.py │ └── scsnet │ │ ├── net.py │ │ ├── train.py │ │ ├── inference_simulated.py │ │ ├── learner.py │ │ └── inference_spi.py ├── inference_simulated.py ├── train.py └── inference_spi.py ├── input └── phi_block_binary.npy ├── requirements.txt ├── config ├── inference_config.yaml ├── scsnet_config.yaml ├── reconnet_config.yaml ├── bcsunet_STL10.yaml ├── bcsunet_SVHN.yaml └── bcsunet_EMNIST.yaml ├── LICENSE ├── README.md └── .gitignore /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | -------------------------------------------------------------------------------- /input/phi_block_binary.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stephenllh/bcs-unet/HEAD/input/phi_block_binary.npy -------------------------------------------------------------------------------- /src/scripts/bcsunet_all.sh: -------------------------------------------------------------------------------- 1 | python train.py -d EMNIST 2 | python train.py -d SVHN 3 | python train.py -d STL10 -------------------------------------------------------------------------------- /src/scripts/scsnet_all.sh: -------------------------------------------------------------------------------- 1 | python -m benchmark.scsnet.train -d EMNIST 2 | python -m benchmark.scsnet.train -d SVHN 3 | python -m benchmark.scsnet.train -d STL10 -------------------------------------------------------------------------------- /src/scripts/reconnet_all.sh: -------------------------------------------------------------------------------- 1 | python -m benchmark.reconnet.train -d EMNIST 2 | python -m benchmark.reconnet.train -d SVHN 3 | # python -m benchmark.reconnet.train -d STL10 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchmetrics==0.3.1 2 | torch==1.8.1+cu111 3 | pytorch_lightning==1.3.0 4 | scipy==1.6.3 5 | torchvision==0.9.1+cu111 6 | numpy==1.20.2 7 | matplotlib==3.4.2 8 | opencv_python==4.5.2.52 9 | albumentations==0.5.2 10 | pandas==1.2.4 11 | PyYAML==5.4.1 12 | -------------------------------------------------------------------------------- /src/scripts/train_all.sh: -------------------------------------------------------------------------------- 1 | python -m benchmark.reconnet.train -d EMNIST 2 | python -m benchmark.reconnet.train -d SVHN 3 | python -m benchmark.reconnet.train -d STL10 4 | python -m benchmark.scsnet.train -d EMNIST 5 | python -m benchmark.scsnet.train -d SVHN 6 | python -m benchmark.scsnet.train -d STL10 7 | python train.py -d EMNIST 8 | python train.py -d SVHN 9 | python train.py -d STL10 -------------------------------------------------------------------------------- /config/inference_config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # This yaml file has the 
configurations for inference. 3 | 4 | gpu: True 5 | # dataset: "STL10" 6 | sampling_ratio: 0.125 7 | measurement_matrix: "../input/phi_block_binary.npy" 8 | real_data: 9 | - filename: "G.mat" 10 | min: 0.001 11 | max: 0.105 12 | - filename: "H.mat" 13 | min: 0.001 14 | max: 0.105 15 | - filename: "plus.mat" 16 | min: 0.01 17 | max: 0.105 18 | - filename: "S.mat" 19 | min: 0.001 20 | max: 0.105 21 | - filename: "R.mat" 22 | min: 0.001 23 | max: 0.12 24 | - filename: "F.mat" 25 | min: 0.002 26 | max: 0.4 -------------------------------------------------------------------------------- /config/scsnet_config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "SCSNet" 3 | bcs: True 4 | 5 | data_module: 6 | data_dir: "../input" 7 | batch_size: 128 8 | val_percent: 0.1 9 | num_workers: 0 10 | 11 | learner: 12 | criterion: "L1" 13 | val_metrics: ["psnr"] 14 | test_metrics: ["psnr", "ssim"] 15 | lr: 0.0003 16 | scheduler: 17 | type: "reduce_lr_on_plateau" # or "one_cycle" 18 | args_reduce_lr_on_plateau: 19 | factor: 0.3 20 | patience: 2 21 | verbose: True 22 | arg_one_cycle: 23 | pct_start: 0.3 24 | verbose: True 25 | 26 | callbacks: 27 | checkpoint: 28 | monitor: "val_loss" 29 | save_last: True 30 | filename: "best" 31 | early_stopping: 32 | monitor: "val_loss" 33 | patience: 5 34 | 35 | trainer: 36 | epochs: 50 37 | gpu: 1 38 | fp16: True 39 | -------------------------------------------------------------------------------- /config/reconnet_config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "reconnet" 3 | img_dim: 32 4 | bcs: True 5 | 6 | data_module: 7 | data_dir: "../input" 8 | batch_size: 128 9 | val_percent: 0.1 10 | num_workers: 4 11 | 12 | learner: 13 | criterion: "L1" 14 | val_metrics: ["psnr"] 15 | test_metrics: ["psnr", "ssim"] 16 | lr: 0.001 17 | scheduler: 18 | type: "reduce_lr_on_plateau" # or "one_cycle" 19 | args_reduce_lr_on_plateau: 20 | factor: 0.3 21 | patience: 2 22 | verbose: True 23 | arg_one_cycle: 24 | pct_start: 0.3 25 | verbose: True 26 | 27 | callbacks: 28 | checkpoint: 29 | monitor: "val_loss" 30 | save_last: True 31 | filename: "best" 32 | early_stopping: 33 | monitor: "val_loss" 34 | patience: 5 35 | 36 | trainer: 37 | epochs: 40 38 | gpu: 1 39 | fp16: True 40 | -------------------------------------------------------------------------------- /src/model/bcsunet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .upsamplenet import UpsampleNet 3 | from .unet import UNet 4 | from .utils import init_weights 5 | 6 | 7 | class BCSUNet(nn.Module): 8 | """Combines simple residual upsample model and U-Net model""" 9 | 10 | def __init__(self, config): 11 | super().__init__() 12 | upsamplenet = UpsampleNet( 13 | sampling_ratio=config["sampling_ratio"], 14 | upsamplenet_config=config["net"]["upsamplenet"], 15 | ) 16 | self.upsamplenet = init_weights( 17 | upsamplenet, init_type=config["net"]["upsamplenet"]["init_type"] 18 | ) 19 | unet = UNet(config["net"]["unet"], input_nc=8) 20 | self.unet = init_weights(unet, init_type=config["net"]["unet"]["init_type"]) 21 | 22 | def forward(self, x): 23 | x = self.upsamplenet(x) 24 | out = self.unet(x) 25 | return out 26 | -------------------------------------------------------------------------------- /config/bcsunet_STL10.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "BCS-UNet" 3 | bcs: True 4 | 5 | 
data_module: 6 | data_dir: "../input" 7 | batch_size: 64 8 | val_percent: 0.1 9 | num_workers: 0 10 | 11 | net: 12 | upsamplenet: 13 | init_type: "xavier" 14 | out_channels_1: 16 15 | out_channels_2: 8 16 | use_transpose_conv: False 17 | spectral_norm: True 18 | unet: 19 | init_type: "xavier" 20 | channels: 16 21 | use_dropout: True 22 | 23 | learner: 24 | intermediate_image: True 25 | criterion: "L1" 26 | val_metrics: ["psnr"] 27 | test_metrics: ["psnr", "ssim"] 28 | lr: 0.0005 29 | weight_decay: 0.0 30 | scheduler: 31 | type: "reduce_lr_on_plateau" # or "one_cycle" 32 | args_reduce_lr_on_plateau: 33 | factor: 0.3 34 | patience: 2 35 | verbose: True 36 | arg_one_cycle: 37 | pct_start: 0.3 38 | verbose: True 39 | 40 | callbacks: 41 | checkpoint: 42 | monitor: "val_loss" 43 | save_last: True 44 | filename: "best" 45 | early_stopping: 46 | monitor: "val_loss" 47 | patience: 6 48 | 49 | trainer: 50 | epochs: 50 51 | gpu: 1 52 | fp16: True 53 | -------------------------------------------------------------------------------- /config/bcsunet_SVHN.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "BCS-UNet" 3 | bcs: True 4 | 5 | data_module: 6 | data_dir: "../input" 7 | batch_size: 128 8 | val_percent: 0.1 9 | num_workers: 4 10 | 11 | net: 12 | upsamplenet: 13 | init_type: "xavier" 14 | out_channels_1: 16 15 | out_channels_2: 8 16 | use_transpose_conv: False 17 | spectral_norm: True 18 | unet: 19 | init_type: "xavier" 20 | channels: 16 21 | use_dropout: False 22 | 23 | learner: 24 | intermediate_image: True 25 | criterion: "L1" 26 | val_metrics: ["psnr"] 27 | test_metrics: ["psnr", "ssim"] 28 | lr: 0.001 29 | weight_decay: 0.001 30 | scheduler: 31 | type: "reduce_lr_on_plateau" # or "one_cycle" 32 | args_reduce_lr_on_plateau: 33 | factor: 0.3 34 | patience: 2 35 | verbose: True 36 | arg_one_cycle: 37 | pct_start: 0.3 38 | verbose: True 39 | 40 | callbacks: 41 | checkpoint: 42 | monitor: "val_loss" 43 | save_last: True 44 | filename: "best" 45 | early_stopping: 46 | monitor: "val_loss" 47 | patience: 5 48 | 49 | trainer: 50 | epochs: 50 51 | gpu: 1 52 | fp16: True 53 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Stephen Lau 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /config/bcsunet_EMNIST.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: "BCS-UNet" 3 | sampling_ratio: 0.125 4 | bcs: True 5 | 6 | data_module: 7 | data_dir: "../input" 8 | batch_size: 128 9 | val_percent: 0.1 10 | num_workers: 4 11 | 12 | net: 13 | upsamplenet: 14 | init_type: "xavier" 15 | out_channels_1: 16 16 | out_channels_2: 8 17 | use_transpose_conv: False 18 | spectral_norm: True 19 | unet: 20 | init_type: "xavier" 21 | channels: 16 22 | use_dropout: False 23 | 24 | learner: 25 | intermediate_image: True 26 | criterion: "L1" 27 | val_metrics: ["psnr"] 28 | test_metrics: ["psnr", "ssim"] 29 | lr: 0.0003 30 | weight_decay: 0.01 31 | scheduler: 32 | type: "reduce_lr_on_plateau" # or "one_cycle" 33 | args_reduce_lr_on_plateau: 34 | factor: 0.3 35 | patience: 2 36 | verbose: True 37 | arg_one_cycle: 38 | pct_start: 0.3 39 | verbose: True 40 | 41 | callbacks: 42 | checkpoint: 43 | monitor: "val_loss" 44 | save_last: True 45 | filename: "best" 46 | early_stopping: 47 | monitor: "val_loss" 48 | patience: 5 49 | 50 | trainer: 51 | epochs: 50 52 | gpu: 1 53 | fp16: True 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BCS-UNet 2 | 3 | 4 | ## About The Project 5 | 6 |
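BCS-UNet reconstructs images from block compressive sensing (BCS) measurements. A non-trainable convolutional operator simulates the block-wise binary measurement process, and a two-stage network (a small residual upsampling net followed by a U-Net) recovers the image. The repo trains and evaluates on EMNIST, SVHN, and STL10, benchmarks against ReconNet and SCSNet, and includes scripts for reconstructing real single-pixel imaging (SPI) measurements.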
7 | 8 | Image 9 |
10 | 11 | 12 | 13 | ## Getting Started 14 | 15 | To get a local copy up and running, follow these simple example steps. 16 |
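### Prerequisites

The pinned dependencies in `requirements.txt` (PyTorch 1.8.1, PyTorch Lightning 1.3.0, torchmetrics 0.3.1, and friends) target Python 3. Note that `torch==1.8.1+cu111` and `torchvision==0.9.1+cu111` are CUDA 11.1 wheels, so a CPU-only setup may need the corresponding CPU builds instead. A fresh virtual environment is recommended:
```sh
python -m venv venv
source venv/bin/activate
```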

17 | 18 | ### Installation 19 | 20 | 1. Clone the repo 21 | ```sh 22 | git clone https://github.com/stephenllh/bcs-unet.git 23 | ``` 24 | 25 | 2. Change directory 26 | ```sh 27 | cd bcs-unet 28 | ``` 29 | 30 | 3. Install the required packages 31 | ```sh 32 | pip install -r requirements.txt 33 | ``` 34 |
35 | 36 | 37 | ## Usage 38 | 39 | 1. To train BCS-UNet, run the training script from the `src` directory with a dataset and a sampling ratio (in percent) 40 | ```sh 41 | cd src 42 | python train.py -d EMNIST -s 12.5 43 | ``` 44 |
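The flags mirror the argparse definitions in the training scripts: `-d` selects the dataset (`EMNIST`, `SVHN`, or `STL10`) and `-s` is the sampling ratio in percent (e.g. `12.5` corresponds to the 0.125 ratio used in the configs).

2. To train the ReconNet and SCSNet benchmarks, run their modules with the same flags (see also the batch scripts in `src/scripts/`)
```sh
python -m benchmark.reconnet.train -d EMNIST -s 12.5
python -m benchmark.scsnet.train -d EMNIST -s 12.5
```

3. To reconstruct test images from simulated measurements with a trained checkpoint
```sh
python inference_simulated.py -d EMNIST -s 12.5
```
Reconstructions are saved under `inference_images/` at the repo root.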
45 | 46 | 47 | 48 | ## License 49 | 50 | Distributed under the MIT License. See `LICENSE` for more information. 51 |

52 | 53 | 54 | 55 | ## Contact 56 | 57 | Stephen Lau - [Email](stephenlaulh@gmail.com) - [Twitter](https://twitter.com/StephenLLH) - [Kaggle](https://www.kaggle.com/faraksuli) 58 | 59 | 60 | -------------------------------------------------------------------------------- /src/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BaseDataset: 8 | def __init__(self, sampling_ratio: float, bcs: bool): 9 | if bcs: 10 | phi = np.load("../input/phi_block_binary.npy") 11 | phi = phi[: int(sampling_ratio * 16)] 12 | phi = torch.FloatTensor(phi) 13 | # print("phi_shape", phi.shape) 14 | self.cs_operator = BCSOperator(phi) 15 | else: # TODO: prepare for CCS 16 | pass 17 | # phi = np.load("../input/________.npy") 18 | # self.cs_operator = CCSOperator(phi) 19 | 20 | 21 | class BCSOperator(nn.Module): 22 | """ 23 | This CNN is not trainable. It serves as a block compressive sensing operator 24 | by utilizing its optimized convolution operations. 25 | """ 26 | 27 | def __init__(self, phi): 28 | super().__init__() 29 | self.register_buffer("phi", phi) 30 | 31 | def forward(self, x): 32 | out = F.conv2d(x, self.phi, stride=4) 33 | return out 34 | 35 | 36 | class CCSOperator(nn.Module): 37 | """ 38 | This CNN is not trainable. It serves as a full-image conventional compressive sensing operator 39 | by utilizing its optimized convolution operations. 40 | """ 41 | 42 | def __init__(self, phi): 43 | super().__init__() 44 | self.register_buffer("phi", phi) 45 | 46 | def forward(self, x): 47 | assert x.shape[-1] == self.phi.shape[-1] 48 | out = F.conv2d(x, self.phi, stride=1) 49 | return out 50 | -------------------------------------------------------------------------------- /src/engine/dispatcher.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.optim import lr_scheduler 3 | import torchmetrics 4 | 5 | 6 | def get_scheduler(optimizer, config): 7 | scheduler_config = config["learner"]["scheduler"] 8 | scheduler_type = scheduler_config["type"] 9 | 10 | if scheduler_type == "reduce_lr_on_plateau": 11 | scheduler = lr_scheduler.ReduceLROnPlateau( 12 | optimizer, **scheduler_config["args_reduce_lr_on_plateau"] 13 | ) 14 | 15 | # TODO: need to work on this (calculate the total number of steps) 16 | elif scheduler_type == "one_cycle": 17 | scheduler = lr_scheduler.OneCycleLR( 18 | optimizer, 19 | **scheduler_config["arg_one_cycle"], 20 | epochs=config["trainer"]["epochs"], 21 | steps_per_epoch=100 22 | ) 23 | 24 | else: 25 | raise NotImplementedError("This scheduler is not implemented.") 26 | 27 | return scheduler 28 | 29 | 30 | def get_criterion(config): 31 | if config["learner"]["criterion"] == "cross_entropy": 32 | return nn.CrossEntropyLoss() 33 | 34 | elif config["learner"]["criterion"] == "L1": 35 | return nn.L1Loss() 36 | 37 | else: 38 | raise NotImplementedError("This loss function is not implemented.") 39 | 40 | 41 | def get_metrics(metric_name, config): 42 | if metric_name == "psnr": 43 | return torchmetrics.PSNR(data_range=1.0, dim=(-2, -1)) 44 | 45 | elif metric_name == "ssim": 46 | return torchmetrics.SSIM(data_range=1.0) 47 | 48 | else: 49 | raise NotImplementedError("This metric is not implemented.") 50 | -------------------------------------------------------------------------------- /src/evaluate.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.utilities.seed import seed_everything 5 | from data.emnist import EMNISTDataModule 6 | from data.svhn import SVHNDataModule 7 | from data.stl10 import STL10DataModule 8 | from engine.learner import BCSUNetLearner 9 | from utils import load_config 10 | 11 | 12 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 13 | 14 | parser = argparse.ArgumentParser(description="Evaluate BCS-UNet on a test set") 15 | parser.add_argument("-d", "--dataset", type=str, help="'EMNIST', 'SVHN', or 'STL10'") 16 | parser.add_argument( 17 | "-s", 18 | "--sampling_ratio", 19 | type=float, 20 | required=True, 21 | help="Sampling ratio in percentage", 22 | ) 23 | 24 | args = parser.parse_args() 25 | 26 | 27 | def run(): 28 | seed_everything(seed=0, workers=True) 29 | path = f"../logs/BCSUNet_{args.dataset}_{int(args.sampling_ratio * 100):04d}"  # matches the log folder naming in train.py 30 | config = load_config(f"{path}/version_0/hparams.yaml") 31 | 32 | if args.dataset == "EMNIST": 33 | data_module = EMNISTDataModule(config) 34 | elif args.dataset == "SVHN": 35 | data_module = SVHNDataModule(config) 36 | elif args.dataset == "STL10": 37 | data_module = STL10DataModule(config) 38 | else: 39 | raise NotImplementedError 40 | 41 | PATH = f"{path}/version_0/checkpoints/best.ckpt" 42 | learner = BCSUNetLearner.load_from_checkpoint(PATH, config=config) 43 | 44 | trainer = pl.Trainer( 45 | gpus=1, 46 | default_root_dir="../", 47 | logger=False, 48 | ) 49 | trainer.test(learner, datamodule=data_module) 50 | 51 | 52 | if __name__ == "__main__": 53 | run() 54 | -------------------------------------------------------------------------------- src/benchmark/reconnet/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ReconNet(nn.Module): 6 | def __init__(self, num_measurements, img_dim): 7 | super().__init__() 8 | self.img_dim = img_dim 9 | self.linear = nn.Linear(num_measurements, img_dim * img_dim) 10 | self.bn = nn.BatchNorm1d(img_dim * img_dim) 11 | self.convs = nn.Sequential( 12 | ConvBlock(in_channels=1, out_channels=64, kernel_size=11), 13 | ConvBlock(in_channels=64, out_channels=32, kernel_size=1), 14 | ConvBlock(in_channels=32, out_channels=1, kernel_size=7), 15 | ConvBlock(in_channels=1, out_channels=64, kernel_size=11), 16 | ConvBlock(in_channels=64, out_channels=32, kernel_size=1), 17 | ) 18 | self.final_conv = nn.Conv2d(32, 1, kernel_size=7, padding=7 // 2) 19 | 20 | def forward(self, x): 21 | if x.dim() == 4: # BCS 22 | x = x.view(x.shape[0], -1) 23 | x = x.unsqueeze(dim=1) 24 | x = self.linear(x) 25 | x = x.view(-1, 1, self.img_dim, self.img_dim) 26 | x = self.convs(x) 27 | x = self.final_conv(x) 28 | out = torch.sigmoid(x) 29 | return out 30 | 31 | 32 | class ConvBlock(nn.Module): 33 | def __init__(self, in_channels, out_channels, kernel_size, stride=1): 34 | super().__init__() 35 | self.conv = nn.Conv2d( 36 | in_channels, 37 | out_channels, 38 | kernel_size=kernel_size, 39 | stride=stride, 40 | padding=kernel_size // 2, 41 | bias=True, 42 | ) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.bn = nn.BatchNorm2d(out_channels) 45 | 46 | def forward(self, x): 47 | x = self.conv(x) 48 | x = self.bn(x) 49 | out = self.relu(x) 50 | return out 51 | -------------------------------------------------------------------------------- src/model/utils.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def init_weights(net, init_type, init_gain=0.02): 5 | """ 6 | Initialize network weights. 7 | Parameters: 8 | net (network) -- network to be initialized 9 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 10 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 11 | """ 12 | 13 | def init_func(m): # define the initialization function 14 | classname = m.__class__.__name__ 15 | if hasattr(m, "weight") and ( 16 | classname.find("Conv") != -1 or classname.find("Linear") != -1 17 | ): 18 | 19 | if init_type == "normal": 20 | nn.init.normal_(m.weight.data, 0.0, init_gain) 21 | 22 | elif init_type == "xavier": 23 | nn.init.xavier_normal_(m.weight.data, gain=init_gain) 24 | 25 | elif init_type == "kaiming": 26 | nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") 27 | 28 | elif init_type == "orthogonal": 29 | nn.init.orthogonal_(m.weight.data, gain=init_gain) 30 | 31 | else: 32 | raise NotImplementedError( 33 | "Initialization method [%s] is not implemented" % init_type 34 | ) 35 | 36 | if hasattr(m, "bias") and m.bias is not None: 37 | nn.init.constant_(m.bias.data, 0.0) 38 | 39 | elif ( 40 | classname.find("BatchNorm2d") != -1 41 | ): # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 42 | nn.init.normal_(m.weight.data, 1.0, init_gain) 43 | nn.init.constant_(m.bias.data, 0.0) 44 | 45 | # print("Initialize network with %s" % init_type) 46 | net.apply(init_func) # apply the initialization function 47 | return net 48 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | import math 5 | import cv2 6 | 7 | 8 | def voltage2pixel(y, phi, low, high): 9 | """ 10 | Converts the measurement tensor / vector from voltage scale to pixel scale, 11 | so that the neural network can understand. 
12 | y - measurement 13 | phi - block measurement matrix 14 | low - lowest voltage pixel 15 | high - highest voltage pixel 16 | """ 17 | 18 | if not torch.is_tensor(y):  # accept numpy arrays as well as tensors 19 | y = torch.from_numpy(y).float() 20 | 21 | phi = torch.from_numpy(phi).float() 22 | 23 | if y.dim() == 4 and y.shape[0] == 1: 24 | y = y.squeeze(dim=0) 25 | 26 | if phi.dim() == 4 and phi.shape[1] == 1: 27 | phi = phi.squeeze(dim=1) 28 | 29 | term1 = y / (high - low) 30 | term2 = (phi.sum(dim=(1, 2)) * low / (high - low)).unsqueeze(-1).unsqueeze(-1) 31 | 32 | y_pixel_scale = term1 - term2 33 | return y_pixel_scale 34 | 35 | 36 | def load_config(config_path): 37 | with open(os.path.join(config_path)) as file: 38 | config = yaml.safe_load(file) 39 | return config 40 | 41 | 42 | def reshape_into_block(y, sr: float, block_size=4): 43 | c = int(sr * block_size ** 2) 44 | h = w = int(math.sqrt(y.shape[0] // c)) 45 | y = y.reshape(h, w, c) 46 | y = y.transpose((2, 0, 1)) 47 | return y 48 | 49 | 50 | def create_patches(root_dir, size): 51 | print(f"Creating {size}x{size} image patches for STL10 test set.") 52 | dataset_dir = os.path.join(root_dir, "test_images") 53 | save_dir = os.path.join(root_dir, "test_images_32x32") 54 | 55 | if not os.path.exists(save_dir): 56 | os.mkdir(save_dir) 57 | 58 | if not os.path.exists(dataset_dir): 59 | os.mkdir(dataset_dir) 60 | 61 | filenames = os.listdir(dataset_dir) 62 | for filename in filenames: 63 | img = cv2.imread(f"{dataset_dir}/{filename}") 64 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 65 | img = img.reshape(img.shape[0] // size, size, img.shape[1] // size, size) 66 | img = img.transpose(0, 2, 1, 3).reshape(-1, size, size) 67 | for idx, img_patch in enumerate(img): 68 | save_filename = f"{save_dir}/{filename.split('.')[0]}_{idx}.png" 69 | cv2.imwrite(save_filename, img_patch) 70 | -------------------------------------------------------------------------------- src/benchmark/scsnet/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class SCSNetInit(nn.Module): 6 | """The "initial reconstruction network" of SCSNet""" 7 | 8 | def __init__(self, in_channels, block_size=4): 9 | super().__init__() 10 | self.block_size = block_size 11 | self.conv = nn.Conv2d(in_channels, block_size ** 2, kernel_size=1) 12 | 13 | def forward(self, x): 14 | x = self.conv(x) 15 | out = self._permute(x) 16 | return out 17 | 18 | def _permute(self, x): 19 | B, C, H, W = x.shape 20 | x = x.permute(0, 2, 3, 1) 21 | x = x.view(B, H, W, self.block_size, self.block_size) 22 | x = x.permute(0, 1, 3, 2, 4).contiguous() 23 | out = x.view(-1, 1, H * self.block_size, W * self.block_size) 24 | return out 25 | 26 | 27 | class SCSNetDeep(nn.Module): 28 | """The "deep reconstruction network" of SCSNet""" 29 | 30 | def __init__(self): 31 | super().__init__() 32 | middle_convs = [ 33 | ConvBlock(in_channels=32, out_channels=32, kernel_size=3) for _ in range(13) 34 | ] 35 | self.convs = nn.Sequential( 36 | ConvBlock(in_channels=1, out_channels=128, kernel_size=3), 37 | ConvBlock(in_channels=128, out_channels=32, kernel_size=3), 38 | *middle_convs, 39 | ConvBlock(in_channels=32, out_channels=128, kernel_size=3), 40 | nn.Conv2d( 41 | in_channels=128, 42 | out_channels=1, 43 | kernel_size=3, 44 | padding=1, 45 | bias=False, 46 | ), 47 | ) 48 | 49 | def forward(self, x): 50 | out = self.convs(x) 51 | return torch.sigmoid(x + out) 52 | 53 | 54 | class ConvBlock(nn.Module): 55 | def __init__(self, in_channels, out_channels, kernel_size, 
stride=1): 56 | super().__init__() 57 | self.conv = nn.Conv2d( 58 | in_channels, 59 | out_channels, 60 | kernel_size=kernel_size, 61 | stride=stride, 62 | padding=kernel_size // 2, 63 | bias=True, 64 | ) 65 | self.relu = nn.ReLU(inplace=True) 66 | self.bn = nn.BatchNorm2d(out_channels) 67 | 68 | def forward(self, x): 69 | x = self.conv(x) 70 | x = self.bn(x) 71 | out = self.relu(x) 72 | return out 73 | -------------------------------------------------------------------------------- /src/inference_simulated.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | from pathlib import Path 5 | import warnings 6 | import cv2 7 | import scipy.ndimage 8 | from data.emnist import EMNISTDataModule 9 | from data.svhn import SVHNDataModule 10 | from data.stl10 import STL10DataModule 11 | from engine.learner import BCSUNetLearner 12 | from utils import load_config 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "-d", "--dataset", type=str, required=True, help="'EMNIST', 'SVHN', or 'STL10'" 18 | ) 19 | parser.add_argument( 20 | "-s", 21 | "--sampling_ratio", 22 | type=float, 23 | required=True, 24 | help="Sampling ratio in percentage", 25 | ) 26 | args = parser.parse_args() 27 | 28 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 29 | warnings.simplefilter("ignore") 30 | 31 | 32 | def run(): 33 | dataset = args.dataset 34 | sr = args.sampling_ratio 35 | checkpoint_path = ( 36 | f"../logs/BCSUNet_{dataset}_{int(sr * 100):04d}/best/checkpoints/last.ckpt" 37 | ) 38 | train_config_path = os.path.join( 39 | Path(checkpoint_path).parent.parent, "hparams.yaml" 40 | ) 41 | train_config = load_config(train_config_path) 42 | train_config["sampling_ratio"] = sr / 100 43 | 44 | if dataset == "EMNIST": 45 | data_module = EMNISTDataModule(train_config) 46 | elif dataset == "SVHN": 47 | data_module = SVHNDataModule(train_config) 48 | elif dataset == "STL10": 49 | data_module = STL10DataModule(train_config) 50 | 51 | learner = BCSUNetLearner.load_from_checkpoint( 52 | checkpoint_path=checkpoint_path, config=train_config, strict=False 53 | ) 54 | 55 | message = f"Inference: BCS-UNet on {dataset} dataset. 
Sampling ratio = {train_config['sampling_ratio']}" 56 | print(message) 57 | 58 | data_module.setup() 59 | ds = data_module.test_dataset 60 | 61 | directory = f"../inference_images/BCSUNet/{dataset}/{int(sr * 100):04d}" 62 | os.makedirs(directory, exist_ok=True) 63 | 64 | for i in np.linspace(0, len(ds) - 1, 30, dtype=int): 65 | input = ds[i][0].unsqueeze(0) 66 | out = learner(input).squeeze().squeeze().detach().numpy() 67 | out = scipy.ndimage.zoom(out, 8, order=0, mode="nearest") 68 | cv2.imwrite(f"{directory}/{i}.png", out * 255) 69 | 70 | print("Done.") 71 | 72 | 73 | if __name__ == "__main__": 74 | run() 75 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.callbacks import ( 5 | ModelCheckpoint, 6 | EarlyStopping, 7 | LearningRateMonitor, 8 | ) 9 | from pytorch_lightning.loggers import TensorBoardLogger 10 | from pytorch_lightning.utilities.seed import seed_everything 11 | from data.emnist import EMNISTDataModule 12 | from data.svhn import SVHNDataModule 13 | from data.stl10 import STL10DataModule 14 | from engine.learner import BCSUNetLearner 15 | from utils import load_config 16 | 17 | 18 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "-d", "--dataset", type=str, required=True, help="'EMNIST', 'SVHN', or 'STL10'" 23 | ) 24 | parser.add_argument( 25 | "-s", 26 | "--sampling_ratio", 27 | type=float, 28 | required=True, 29 | help="Sampling ratio in percentage", 30 | ) 31 | args = parser.parse_args() 32 | 33 | 34 | def run(): 35 | seed_everything(seed=0, workers=True) 36 | 37 | config = load_config(f"../config/bcsunet_{args.dataset}.yaml") 38 | config["sampling_ratio"] = args.sampling_ratio / 100 39 | 40 | if args.dataset == "EMNIST": 41 | data_module = EMNISTDataModule(config) 42 | elif args.dataset == "SVHN": 43 | data_module = SVHNDataModule(config) 44 | elif args.dataset == "STL10": 45 | data_module = STL10DataModule(config) 46 | 47 | learner = BCSUNetLearner(config) 48 | callbacks = [ 49 | ModelCheckpoint(**config["callbacks"]["checkpoint"]), 50 | EarlyStopping(**config["callbacks"]["early_stopping"]), 51 | LearningRateMonitor(), 52 | ] 53 | log_name = f"BCSUNet_{args.dataset}_{int(config['sampling_ratio'] * 10000):04d}" 54 | logger = TensorBoardLogger(save_dir="../logs", name=log_name) 55 | 56 | message = f"Running BCS-UNet on {args.dataset} dataset. 
Sampling ratio = {config['sampling_ratio']}" 57 | print("-" * 100) 58 | print(message) 59 | print("-" * 100) 60 | 61 | trainer = pl.Trainer( 62 | gpus=config["trainer"]["gpu"], 63 | max_epochs=config["trainer"]["epochs"], 64 | default_root_dir="../", 65 | callbacks=callbacks, 66 | precision=(16 if config["trainer"]["fp16"] else 32), 67 | logger=logger, 68 | ) 69 | trainer.fit(learner, data_module) 70 | trainer.test(learner, datamodule=data_module, ckpt_path="best") 71 | 72 | 73 | if __name__ == "__main__": 74 | run() 75 | -------------------------------------------------------------------------------- /src/benchmark/scsnet/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.callbacks import ( 5 | ModelCheckpoint, 6 | EarlyStopping, 7 | LearningRateMonitor, 8 | ) 9 | from pytorch_lightning.loggers import TensorBoardLogger 10 | from pytorch_lightning.utilities.seed import seed_everything 11 | from data.emnist import EMNISTDataModule 12 | from data.svhn import SVHNDataModule 13 | from data.stl10 import STL10DataModule 14 | from .learner import SCSNetLearner 15 | from utils import load_config 16 | 17 | 18 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("-d", "--dataset", type=str, help="'EMNIST', 'SVHN', or 'STL10'") 22 | parser.add_argument( 23 | "-s", 24 | "--sampling_ratio", 25 | type=float, 26 | required=True, 27 | help="Sampling ratio in percentage", 28 | ) 29 | args = parser.parse_args() 30 | 31 | 32 | def run(): 33 | seed_everything(seed=0, workers=True) 34 | 35 | config = load_config("../config/scsnet_config.yaml") 36 | config["sampling_ratio"] = args.sampling_ratio / 100 37 | 38 | if args.dataset == "EMNIST": 39 | data_module = EMNISTDataModule(config) 40 | elif args.dataset == "SVHN": 41 | data_module = SVHNDataModule(config) 42 | elif args.dataset == "STL10": 43 | config["data_module"]["batch_size"] = 64 44 | data_module = STL10DataModule(config) 45 | else: 46 | raise NotImplementedError 47 | 48 | learner = SCSNetLearner(config) 49 | callbacks = [ 50 | ModelCheckpoint(**config["callbacks"]["checkpoint"]), 51 | EarlyStopping(**config["callbacks"]["early_stopping"]), 52 | LearningRateMonitor(), 53 | ] 54 | 55 | log_name = f"SCSNet_{args.dataset}_{int(config['sampling_ratio'] * 10000):04d}" 56 | logger = TensorBoardLogger(save_dir="../logs", name=log_name) 57 | 58 | message = f"Running SCSNet on {args.dataset} dataset. 
Sampling ratio = {config['sampling_ratio']}" 59 | print("-" * 100) 60 | print(message) 61 | print("-" * 100) 62 | 63 | trainer = pl.Trainer( 64 | gpus=config["trainer"]["gpu"], 65 | max_epochs=config["trainer"]["epochs"], 66 | default_root_dir="../", 67 | callbacks=callbacks, 68 | precision=(16 if config["trainer"]["fp16"] else 32), 69 | logger=logger, 70 | ) 71 | trainer.fit(learner, data_module) 72 | trainer.test(learner, datamodule=data_module, ckpt_path="best") 73 | 74 | 75 | if __name__ == "__main__": 76 | run() 77 | -------------------------------------------------------------------------------- /src/benchmark/scsnet/inference_simulated.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from pathlib import Path 5 | import warnings 6 | import cv2 7 | import scipy.ndimage 8 | from data.emnist import EMNISTDataModule 9 | from data.svhn import SVHNDataModule 10 | from data.stl10 import STL10DataModule 11 | from .learner import SCSNetLearner 12 | from utils import load_config 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "-d", "--dataset", type=str, required=True, help="'EMNIST', 'SVHN', or 'STL10'" 18 | ) 19 | parser.add_argument( 20 | "-s", 21 | "--sampling_ratio", 22 | type=float, 23 | required=True, 24 | help="Sampling ratio in percentage", 25 | ) 26 | args = parser.parse_args() 27 | 28 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 29 | warnings.simplefilter("ignore") 30 | 31 | 32 | def run(): 33 | # inference_config = load_config("../config/inference_config.yaml") 34 | dataset = args.dataset 35 | sr = args.sampling_ratio 36 | checkpoint_path = ( 37 | f"../logs/SCSNet_{dataset}_{int(sr * 100):04d}/best/checkpoints/last.ckpt" 38 | ) 39 | train_config_path = os.path.join( 40 | Path(checkpoint_path).parent.parent, "hparams.yaml" 41 | ) 42 | train_config = load_config(train_config_path) 43 | train_config["sampling_ratio"] = sr / 100 44 | 45 | if dataset == "EMNIST": 46 | data_module = EMNISTDataModule(train_config) 47 | train_config["img_dim"] = 32 48 | elif dataset == "SVHN": 49 | data_module = SVHNDataModule(train_config) 50 | train_config["img_dim"] = 32 51 | elif dataset == "STL10": 52 | data_module = STL10DataModule(train_config) 53 | train_config["img_dim"] = 96 54 | 55 | learner = SCSNetLearner.load_from_checkpoint( 56 | checkpoint_path=checkpoint_path, config=train_config, strict=False 57 | ) 58 | 59 | message = f"Inference: SCSNet on {dataset} dataset. 
Sampling ratio = {train_config['sampling_ratio']}" 60 | print(message) 61 | 62 | data_module.setup() 63 | ds = data_module.test_dataset 64 | 65 | directory = f"../inference_images/SCSNet/{dataset}/{int(sr * 100):04d}" 66 | os.makedirs(directory, exist_ok=True) 67 | 68 | for i in np.linspace(0, len(ds) - 1, 30, dtype=int): 69 | input = ds[i][0].unsqueeze(0) 70 | out = learner(input).squeeze().squeeze().detach().numpy() 71 | out = scipy.ndimage.zoom(out, 8, order=0, mode="nearest") 72 | cv2.imwrite(f"{directory}/{i}.png", out * 255) 73 | 74 | print("Done.") 75 | 76 | 77 | if __name__ == "__main__": 78 | run() 79 | -------------------------------------------------------------------------------- /src/benchmark/reconnet/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.callbacks import ( 5 | ModelCheckpoint, 6 | EarlyStopping, 7 | LearningRateMonitor, 8 | ) 9 | from pytorch_lightning.loggers import TensorBoardLogger 10 | from pytorch_lightning.utilities.seed import seed_everything 11 | from data.emnist import EMNISTDataModule 12 | from data.svhn import SVHNDataModule 13 | from data.stl10 import STL10DataModule 14 | from .learner import ReconNetLearner 15 | from utils import load_config 16 | 17 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("-d", "--dataset", type=str, help="'EMNIST', 'SVHN', or 'STL10'") 21 | parser.add_argument( 22 | "-s", 23 | "--sampling_ratio", 24 | type=float, 25 | required=True, 26 | help="Sampling ratio in percentage", 27 | ) 28 | args = parser.parse_args() 29 | 30 | 31 | def run(): 32 | seed_everything(seed=0, workers=True) 33 | 34 | config = load_config("../config/reconnet_config.yaml") 35 | config["sampling_ratio"] = args.sampling_ratio / 100 36 | 37 | if args.dataset == "EMNIST": 38 | data_module = EMNISTDataModule(config) 39 | elif args.dataset == "SVHN": 40 | data_module = SVHNDataModule(config) 41 | elif args.dataset == "STL10": 42 | config["data_module"]["num_workers"] = 0 43 | data_module = STL10DataModule(config, reconnet=True) 44 | else: 45 | raise NotImplementedError 46 | 47 | learner = ReconNetLearner(config) 48 | callbacks = [ 49 | ModelCheckpoint(**config["callbacks"]["checkpoint"]), 50 | EarlyStopping(**config["callbacks"]["early_stopping"]), 51 | LearningRateMonitor(), 52 | ] 53 | 54 | sampling_ratio = config["sampling_ratio"] 55 | log_name = f"ReconNet_{args.dataset}_{int(sampling_ratio * 10000):04d}" 56 | logger = TensorBoardLogger(save_dir="../logs", name=log_name) 57 | 58 | message = f"Running ReconNet on {args.dataset} dataset. 
Sampling ratio = {sampling_ratio * 100}%" 59 | print("-" * 100) 60 | print(message) 61 | print("-" * 100) 62 | 63 | trainer = pl.Trainer( 64 | gpus=config["trainer"]["gpu"], 65 | max_epochs=config["trainer"]["epochs"], 66 | default_root_dir="../", 67 | progress_bar_refresh_rate=20, 68 | callbacks=callbacks, 69 | precision=(16 if config["trainer"]["fp16"] else 32), 70 | logger=logger, 71 | ) 72 | trainer.fit(learner, data_module) 73 | trainer.test(learner, datamodule=data_module, ckpt_path="best") 74 | 75 | 76 | if __name__ == "__main__": 77 | run() 78 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # VSCode 141 | **/.vscode 142 | **/__pycache__ 143 | 144 | # Input data and models 145 | **/input/** 146 | !/input/phi_block_binary.npy 147 | **/logs 148 | **/lightning_logs 149 | **/inference_input 150 | **/inference_images 151 | **/wandb 152 | **/saved_models 153 | **/src/tests 154 | 155 | # Data files 156 | *.csv 157 | *.h5 158 | *.pkl 159 | *.pth 160 | *.bin 161 | *.pyc 162 | *.png 163 | *.ckpt 164 | -------------------------------------------------------------------------------- /src/engine/learner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | from model.bcsunet import BCSUNet 4 | from engine.dispatcher import get_scheduler, get_criterion, get_metrics 5 | 6 | 7 | class BCSUNetLearner(pl.LightningModule): 8 | def __init__(self, config): 9 | super().__init__() 10 | self.net = BCSUNet(config) 11 | self.config = config 12 | self.criterion = get_criterion(config) 13 | self._set_metrics(config) 14 | self.save_hyperparameters(config) 15 | 16 | def forward(self, inputs): 17 | reconstructed_image = self.net(inputs) 18 | return reconstructed_image 19 | 20 | def configure_optimizers(self): 21 | optimizer = torch.optim.AdamW( 22 | self.parameters(), 23 | lr=self.config["learner"]["lr"], 24 | weight_decay=self.config["learner"]["weight_decay"], 25 | ) 26 | scheduler = get_scheduler(optimizer, self.config) 27 | return { 28 | "optimizer": optimizer, 29 | "lr_scheduler": scheduler, 30 | "monitor": "val_loss", 31 | } 32 | 33 | def step(self, batch, mode="train"): 34 | inputs, targets = batch 35 | preds = self.net(inputs) 36 | loss = self.criterion(preds, targets) 37 | self.log(f"{mode}_loss", loss, prog_bar=True) 38 | 39 | if mode == "val": 40 | preds_ = preds.float().detach() 41 | for metric_name in self.config["learner"]["val_metrics"]: 42 | MetricClass = self.__getattr__(f"{mode}_{metric_name}") 43 | if MetricClass is not None: 44 | self.log( 45 | f"{mode}_{metric_name}", 46 | MetricClass(preds_, targets), 47 | prog_bar=True, 48 | ) 49 | return loss 50 | 51 | def training_step(self, batch, batch_idx): 52 | return self.step(batch, mode="train") 53 | 54 | def validation_step(self, batch, batch_idx): 55 | return self.step(batch, mode="val") 56 | 57 | def test_step(self, batch, batch_idx): 58 | inputs, targets = batch 59 | preds = self.net(inputs) 60 | for metric_name in self.config["learner"]["test_metrics"]: 61 | metric = self.__getattr__(f"test_{metric_name}") 62 | self.log( 63 | f"test_{metric_name}", 64 | metric(preds.float(), targets), 65 | on_step=False, 66 | on_epoch=True, 67 | prog_bar=True, 68 | ) 69 | 70 | def _set_metrics(self, config): 71 | """ 72 | Set TorchMetrics as attributes in a dynamical manner. 
73 | For instance, `self.train_accuracy = torchmetrics.Accuracy()` 74 | """ 75 | for metric_name in config["learner"]["val_metrics"]: 76 | # self.__setattr__(f"train_{metric_name}", get_metrics(metric_name, config)) 77 | self.__setattr__(f"val_{metric_name}", get_metrics(metric_name, config)) 78 | 79 | for metric_name in config["learner"]["test_metrics"]: 80 | self.__setattr__(f"test_{metric_name}", get_metrics(metric_name, config)) 81 | -------------------------------------------------------------------------------- /src/benchmark/reconnet/learner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | from .net import ReconNet 4 | from engine.dispatcher import get_scheduler, get_criterion, get_metrics 5 | 6 | 7 | class ReconNetLearner(pl.LightningModule): 8 | def __init__(self, config): 9 | super().__init__() 10 | y_dim = int(config["sampling_ratio"] * config["img_dim"] ** 2) 11 | self.net = ReconNet(y_dim, config["img_dim"]) 12 | self.config = config 13 | self.criterion = get_criterion(config) 14 | self._set_metrics(config) 15 | self.save_hyperparameters(config) 16 | 17 | def forward(self, inputs): 18 | return self.net(inputs) 19 | 20 | def configure_optimizers(self): 21 | trainable_params = filter(lambda p: p.requires_grad, self.net.parameters()) 22 | optimizer = torch.optim.Adam( 23 | trainable_params, 24 | lr=self.config["learner"]["lr"], 25 | ) 26 | scheduler = get_scheduler(optimizer, self.config) 27 | return { 28 | "optimizer": optimizer, 29 | "lr_scheduler": scheduler, 30 | "monitor": "val_loss", 31 | } 32 | 33 | def step(self, batch, mode="train"): 34 | inputs, targets = batch 35 | preds = self.net(inputs) 36 | loss = self.criterion(preds, targets) 37 | 38 | self.log(f"{mode}_loss", loss, prog_bar=True) 39 | 40 | preds_ = preds.float().detach() 41 | 42 | # Log validation metrics 43 | if mode == "val": 44 | for metric_name in self.config["learner"]["val_metrics"]: 45 | metric = self.__getattr__(f"{mode}_{metric_name}") 46 | self.log( 47 | f"{mode}_{metric_name}", 48 | metric(preds_, targets), 49 | prog_bar=True, 50 | ) 51 | return loss 52 | 53 | def training_step(self, batch, batch_idx): 54 | return self.step(batch, mode="train") 55 | 56 | def validation_step(self, batch, batch_idx): 57 | return self.step(batch, mode="val") 58 | 59 | def test_step(self, batch, batch_idx): 60 | inputs, targets = batch 61 | preds = self.net(inputs) 62 | for metric_name in self.config["learner"]["test_metrics"]: 63 | metric = self.__getattr__(f"test_{metric_name}") 64 | self.log( 65 | f"test_{metric_name}", 66 | metric(preds.float(), targets), 67 | prog_bar=True, 68 | on_step=False, 69 | on_epoch=True, 70 | ) 71 | 72 | def _set_metrics(self, config): 73 | """ 74 | Set TorchMetrics as attributes in a dynamical manner. 
75 | For instance, `self.train_accuracy = torchmetrics.Accuracy()` 76 | """ 77 | for metric_name in config["learner"]["val_metrics"]: 78 | # self.__setattr__(f"train_{metric_name}", get_metrics(metric_name, config)) 79 | self.__setattr__(f"val_{metric_name}", get_metrics(metric_name, config)) 80 | 81 | for metric_name in config["learner"]["test_metrics"]: 82 | self.__setattr__(f"test_{metric_name}", get_metrics(metric_name, config)) 83 | -------------------------------------------------------------------------------- /src/benchmark/scsnet/learner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | from .net import SCSNetInit, SCSNetDeep 4 | from engine.dispatcher import get_scheduler, get_criterion, get_metrics 5 | 6 | 7 | class SCSNetLearner(pl.LightningModule): 8 | def __init__(self, config): 9 | super().__init__() 10 | in_channels = int(config["sampling_ratio"] * 16) 11 | self.net1 = SCSNetInit(in_channels) 12 | self.net2 = SCSNetDeep() 13 | self.config = config 14 | self.criterion = get_criterion(config) 15 | self._set_metrics(config) 16 | self.save_hyperparameters(config) 17 | 18 | def forward(self, inputs): 19 | return self.net2(self.net1(inputs)) 20 | 21 | def configure_optimizers(self): 22 | optimizer = torch.optim.Adam( 23 | self.parameters(), 24 | lr=self.config["learner"]["lr"], 25 | ) 26 | scheduler = get_scheduler(optimizer, self.config) 27 | return { 28 | "optimizer": optimizer, 29 | "lr_scheduler": scheduler, 30 | "monitor": "val_loss", 31 | } 32 | 33 | def step(self, batch, mode="train"): 34 | inputs, targets = batch 35 | 36 | preds1 = self.net1(inputs) 37 | preds2 = self.net2(preds1) 38 | 39 | loss1 = self.criterion(preds1, targets) 40 | loss2 = self.criterion(preds2, targets) 41 | loss = loss1 + loss2 42 | 43 | preds_ = preds2.float().detach() 44 | targets_ = targets.detach() 45 | 46 | # Log validation metrics 47 | if mode == "val": 48 | self.log(f"{mode}_loss", loss2, prog_bar=True) 49 | for metric_name in self.config["learner"]["val_metrics"]: 50 | metric = self.__getattr__(f"{mode}_{metric_name}") 51 | self.log( 52 | f"{mode}_{metric_name}", 53 | metric(preds_, targets_), 54 | prog_bar=True, 55 | ) 56 | return loss 57 | 58 | def training_step(self, batch, batch_idx): 59 | return self.step(batch, mode="train") 60 | 61 | def validation_step(self, batch, batch_idx): 62 | return self.step(batch, mode="val") 63 | 64 | def test_step(self, batch, batch_idx): 65 | inputs, targets = batch 66 | preds1 = self.net1(inputs) 67 | preds2 = self.net2(preds1).float() 68 | for metric_name in self.config["learner"]["test_metrics"]: 69 | metric = self.__getattr__(f"test_{metric_name}") 70 | self.log( 71 | f"test_{metric_name}", 72 | metric(preds2.float(), targets), 73 | on_step=False, 74 | on_epoch=True, 75 | prog_bar=True, 76 | ) 77 | 78 | def _set_metrics(self, config): 79 | """ 80 | Set TorchMetrics as attributes in a dynamical manner. 
81 | For instance, `self.train_accuracy = torchmetrics.Accuracy()` 82 | """ 83 | for metric_name in config["learner"]["val_metrics"]: 84 | # self.__setattr__(f"train_{metric_name}", get_metrics(metric_name, config)) 85 | self.__setattr__(f"val_{metric_name}", get_metrics(metric_name, config)) 86 | 87 | for metric_name in config["learner"]["test_metrics"]: 88 | self.__setattr__(f"test_{metric_name}", get_metrics(metric_name, config)) 89 | -------------------------------------------------------------------------------- /src/benchmark/reconnet/inference_simulated.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from pathlib import Path 5 | import warnings 6 | import cv2 7 | import scipy.ndimage 8 | from data.emnist import EMNISTDataModule 9 | from data.svhn import SVHNDataModule 10 | from data.stl10 import STL10DataModule 11 | from .learner import ReconNetLearner 12 | from utils import load_config, create_patches 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "-d", "--dataset", type=str, required=True, help="'EMNIST', 'SVHN', or 'STL10'" 18 | ) 19 | parser.add_argument( 20 | "-s", 21 | "--sampling_ratio", 22 | type=float, 23 | required=True, 24 | help="Sampling ratio in percentage", 25 | ) 26 | args = parser.parse_args() 27 | 28 | 29 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 30 | warnings.simplefilter("ignore") 31 | 32 | 33 | def run(): 34 | dataset = args.dataset 35 | sr = args.sampling_ratio 36 | 37 | checkpoint_folder = f"../logs/ReconNet_STL10_{int(sr * 100):04d}/best" 38 | 39 | if not os.path.exists(checkpoint_folder): 40 | run_name = os.listdir(Path(checkpoint_folder).parent)[0] 41 | checkpoint_path = ( 42 | f"{Path(checkpoint_folder).parent}/{run_name}/checkpoints/last.ckpt" 43 | ) 44 | message = ( 45 | f"The checkpoint from the run '{run_name}' is selected by default." 46 | + "If this is not intended, change the name of the preferred checkpoint folder to 'best'." 47 | ) 48 | print(message) 49 | else: 50 | checkpoint_path = f"{checkpoint_folder}/checkpoints/last.ckpt" 51 | 52 | train_config_path = os.path.join( 53 | Path(checkpoint_path).parent.parent, "hparams.yaml" 54 | ) 55 | train_config = load_config(train_config_path) 56 | train_config["sampling_ratio"] = sr / 100 57 | train_config["img_dim"] = 32 58 | 59 | if dataset == "EMNIST": 60 | data_module = EMNISTDataModule(train_config) 61 | elif dataset == "SVHN": 62 | data_module = SVHNDataModule(train_config) 63 | elif dataset == "STL10": 64 | data_module = STL10DataModule(train_config, reconnet=True) 65 | 66 | learner = ReconNetLearner.load_from_checkpoint( 67 | checkpoint_path=checkpoint_path, config=train_config, strict=False 68 | ) 69 | 70 | message = f"Inference: ReconNet on {dataset} dataset. 
Sampling ratio = {train_config['sampling_ratio']}" 71 | print(message) 72 | 73 | data_module.setup() 74 | ds = data_module.test_dataset 75 | 76 | directory = f"../inference_images/ReconNet/{dataset}/{int(sr * 100):04d}" 77 | os.makedirs(directory, exist_ok=True) 78 | 79 | if dataset != "STL10": 80 | for i in np.linspace(0, len(ds) - 1, 30, dtype=int): 81 | input = ds[i][0].unsqueeze(0) 82 | out = learner(input).squeeze().squeeze().detach().numpy() 83 | out = scipy.ndimage.zoom(out, 8, order=0, mode="nearest") 84 | cv2.imwrite(f"{directory}/{i}.png", out * 255) 85 | 86 | else: 87 | if not os.path.exists("../input/STL10/test_images_32x32"): 88 | create_patches("../input/STL10", size=32) 89 | for i in range(10): 90 | combined = np.zeros((96, 96), dtype=int) 91 | for j in range(9): 92 | input = ds[i * 9 + j][0].unsqueeze(0) 93 | out = learner(input).squeeze().squeeze().detach().numpy() 94 | out = (out * 255).astype(int) 95 | x = j // 3 * 32 96 | y = j % 3 * 32 97 | combined[x : x + 32, y : y + 32] = out 98 | combined = scipy.ndimage.zoom(combined, 8, order=0, mode="nearest") 99 | cv2.imwrite(f"{directory}/{i}.png", combined) 100 | 101 | print("Done.") 102 | 103 | 104 | if __name__ == "__main__": 105 | run() 106 | -------------------------------------------------------------------------------- /src/data/svhn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.sampler import SubsetRandomSampler 4 | import torchvision 5 | from torchvision import transforms 6 | import pytorch_lightning as pl 7 | from .base_dataset import BaseDataset 8 | 9 | 10 | class SVHNDataset(BaseDataset): 11 | def __init__(self, sampling_ratio: float, bcs: bool, tfms=None, train=True): 12 | super().__init__(sampling_ratio=sampling_ratio, bcs=bcs) 13 | self.data = torchvision.datasets.SVHN( 14 | "../input/SVHN", 15 | split="train" if train else "test", 16 | transform=tfms, 17 | download=True, 18 | ) 19 | 20 | def __getitem__(self, idx): 21 | image, _ = self.data[idx] 22 | image_ = image.unsqueeze(dim=1) 23 | y = self.cs_operator(image_) 24 | return y.squeeze(dim=0), image 25 | 26 | def __len__(self): 27 | return len(self.data) 28 | 29 | 30 | class SVHNDataModule(pl.LightningDataModule): 31 | def __init__(self, config): 32 | super().__init__() 33 | self.config = config 34 | self.dm_config = config["data_module"] 35 | 36 | def setup(self, stage=None): 37 | train_tfms = transforms.Compose( 38 | [ 39 | transforms.Resize(32), 40 | transforms.RandomHorizontalFlip(p=0.5), 41 | transforms.ColorJitter(brightness=0.2, contrast=0.2), 42 | transforms.Grayscale(), 43 | transforms.ToTensor(), 44 | ] 45 | ) 46 | val_tfms = transforms.Compose( 47 | [ 48 | transforms.Resize(32), 49 | transforms.Grayscale(), 50 | transforms.ToTensor(), 51 | ] 52 | ) 53 | 54 | self.train_dataset = SVHNDataset( 55 | sampling_ratio=self.config["sampling_ratio"], 56 | bcs=self.config["bcs"], 57 | tfms=train_tfms, 58 | train=True, 59 | ) 60 | 61 | self.val_dataset = SVHNDataset( 62 | sampling_ratio=self.config["sampling_ratio"], 63 | bcs=self.config["bcs"], 64 | tfms=val_tfms, 65 | train=True, 66 | ) 67 | 68 | self.test_dataset = SVHNDataset( 69 | sampling_ratio=self.config["sampling_ratio"], 70 | bcs=self.config["bcs"], 71 | tfms=val_tfms, 72 | train=False, 73 | ) 74 | 75 | dataset_size = len(self.train_dataset) 76 | indices = torch.randperm(dataset_size) 77 | split = int(self.dm_config["val_percent"] * dataset_size) 78 | self.train_idx, 
self.valid_idx = indices[split:], indices[:split] 79 | 80 | def train_dataloader(self): 81 | train_sampler = SubsetRandomSampler(self.train_idx) 82 | return DataLoader( 83 | self.train_dataset, 84 | batch_size=self.dm_config["batch_size"], 85 | sampler=train_sampler, 86 | num_workers=self.dm_config["num_workers"], 87 | ) 88 | 89 | def val_dataloader(self): 90 | val_sampler = SubsetRandomSampler(self.valid_idx) 91 | return DataLoader( 92 | self.val_dataset, 93 | batch_size=self.dm_config["batch_size"], 94 | sampler=val_sampler, 95 | num_workers=self.dm_config["num_workers"], 96 | ) 97 | 98 | def test_dataloader(self): 99 | return DataLoader( 100 | self.test_dataset, 101 | batch_size=self.dm_config["batch_size"], 102 | num_workers=self.dm_config["num_workers"], 103 | ) 104 | 105 | def predict_dataloader(self): 106 | return DataLoader( 107 | self.test_dataset, 108 | batch_size=self.dm_config["batch_size"], 109 | num_workers=self.dm_config["num_workers"], 110 | ) 111 | -------------------------------------------------------------------------------- /src/data/emnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.sampler import SubsetRandomSampler 4 | import torchvision 5 | from torchvision import transforms 6 | import pytorch_lightning as pl 7 | from .base_dataset import BaseDataset 8 | 9 | 10 | class EMNISTDataset(BaseDataset): 11 | def __init__(self, sampling_ratio: float, bcs: bool, tfms=None, train=True): 12 | super().__init__(sampling_ratio=sampling_ratio, bcs=bcs) 13 | self.data = torchvision.datasets.EMNIST( 14 | "../input", 15 | split="letters", # we use letters only 16 | train=train, 17 | transform=tfms, 18 | download=True, 19 | ) 20 | 21 | def __getitem__(self, idx): 22 | image, _ = self.data[idx] 23 | image_ = image.unsqueeze(dim=1) 24 | y = self.cs_operator(image_) 25 | return y.squeeze(dim=0), image 26 | 27 | def __len__(self): 28 | return len(self.data) 29 | 30 | 31 | class EMNISTDataModule(pl.LightningDataModule): 32 | def __init__(self, config): 33 | super().__init__() 34 | self.config = config 35 | self.dm_config = config["data_module"] 36 | 37 | def setup(self, stage=None): 38 | train_tfms = transforms.Compose( 39 | [ 40 | transforms.Resize(32), 41 | transforms.RandomHorizontalFlip(p=0.5), 42 | transforms.ColorJitter(brightness=0.2, contrast=0.2), 43 | transforms.Grayscale(), 44 | transforms.ToTensor(), 45 | ] 46 | ) 47 | val_tfms = transforms.Compose( 48 | [ 49 | transforms.Resize(32), 50 | transforms.Grayscale(), 51 | transforms.ToTensor(), 52 | ] 53 | ) 54 | 55 | self.train_dataset = EMNISTDataset( 56 | sampling_ratio=self.config["sampling_ratio"], 57 | bcs=self.config["bcs"], 58 | tfms=train_tfms, 59 | train=True, 60 | ) 61 | 62 | self.val_dataset = EMNISTDataset( 63 | sampling_ratio=self.config["sampling_ratio"], 64 | bcs=self.config["bcs"], 65 | tfms=val_tfms, 66 | train=True, 67 | ) 68 | 69 | self.test_dataset = EMNISTDataset( 70 | sampling_ratio=self.config["sampling_ratio"], 71 | bcs=self.config["bcs"], 72 | tfms=val_tfms, 73 | train=False, 74 | ) 75 | 76 | dataset_size = len(self.train_dataset) 77 | indices = torch.randperm(dataset_size) 78 | split = int(self.dm_config["val_percent"] * dataset_size) 79 | self.train_idx, self.valid_idx = indices[split:], indices[:split] 80 | 81 | def train_dataloader(self): 82 | train_sampler = SubsetRandomSampler(self.train_idx) 83 | return DataLoader( 84 | self.train_dataset, 85 | 
batch_size=self.dm_config["batch_size"], 86 | sampler=train_sampler, 87 | num_workers=self.dm_config["num_workers"], 88 | ) 89 | 90 | def val_dataloader(self): 91 | val_sampler = SubsetRandomSampler(self.valid_idx) 92 | return DataLoader( 93 | self.val_dataset, 94 | batch_size=self.dm_config["batch_size"], 95 | sampler=val_sampler, 96 | num_workers=self.dm_config["num_workers"], 97 | ) 98 | 99 | def test_dataloader(self): 100 | return DataLoader( 101 | self.test_dataset, 102 | batch_size=self.dm_config["batch_size"], 103 | num_workers=self.dm_config["num_workers"], 104 | ) 105 | 106 | def predict_dataloader(self): 107 | return DataLoader( 108 | self.test_dataset, 109 | batch_size=self.dm_config["batch_size"], 110 | num_workers=self.dm_config["num_workers"], 111 | ) 112 | -------------------------------------------------------------------------------- /src/benchmark/reconnet/inference_spi.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import time 4 | import argparse 5 | import warnings 6 | import numpy as np 7 | import cv2 8 | import scipy.ndimage 9 | import scipy.io 10 | import math 11 | import torch 12 | import pytorch_lightning as pl 13 | from .learner import ReconNetLearner 14 | from utils import voltage2pixel, load_config 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "-s", 20 | "--sampling_ratio", 21 | type=float, 22 | required=True, 23 | help="Sampling ratio in percentage", 24 | ) 25 | args = parser.parse_args() 26 | 27 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 28 | warnings.simplefilter("ignore") 29 | 30 | 31 | def setup(): 32 | inference_config = load_config("../config/inference_config.yaml") 33 | sr = args.sampling_ratio 34 | 35 | checkpoint_folder = f"../logs/ReconNet_STL10_{int(sr * 100):04d}/best" 36 | 37 | if not os.path.exists(checkpoint_folder): 38 | run_name = os.listdir(Path(checkpoint_folder).parent)[-1] 39 | checkpoint_path = ( 40 | f"{Path(checkpoint_folder).parent}/{run_name}/checkpoints/last.ckpt" 41 | ) 42 | message = ( 43 | f"The checkpoint from the run '{run_name}' is selected by default. " 44 | + "If this is not intended, change the name of the preferred checkpoint folder to 'best'." 45 | ) 46 | print(message) 47 | else: 48 | checkpoint_path = f"{checkpoint_folder}/checkpoints/last.ckpt" 49 | 50 | learner = ReconNetLearner.load_from_checkpoint(checkpoint_path=checkpoint_path) 51 | 52 | trainer = pl.Trainer( 53 | gpus=1 if inference_config["gpu"] else 0, 54 | logger=False, 55 | default_root_dir="../", 56 | ) 57 | return learner, trainer 58 | 59 | 60 | class RealDataset: 61 | def __init__(self, sampling_ratio, inference_config): 62 | self.real_data = inference_config["real_data"] 63 | self.phi = np.load(inference_config["measurement_matrix"]) 64 | self.c = int(sampling_ratio / 100 * 16) 65 | 66 | def __getitem__(self, idx): 67 | real_data = self.real_data[idx] 68 | path = os.path.join("../inference_input", real_data["filename"]) 69 | y_input = scipy.io.loadmat(path)["y"] 70 | 71 | y_input = y_input[ 72 | np.mod(np.arange(len(y_input)), len(y_input) // 64) < self.c 73 | ] # discard extra measurements 74 | 75 | y_input = torch.FloatTensor(y_input).permute(1, 0) 76 | y_input -= y_input.min() 77 | y_input /= real_data["max"] 78 | 79 | # Permute is necessary because during sampling, we used "channel-last" format. 
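# (Worked example, assuming a 25% sampling ratio: c = int(25 / 100 * 16) = 4 measurements per 4x4 block; a 32x32 scene has 64 blocks, so the flat measurement vector is reshaped to (64, 4), permuted to (4, 64), and viewed as (4, 8, 8) below.)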
80 | # Hence, we need to permute it to become channel-first to match PyTorch "channel-first" format 81 | y_input = y_input.view(-1, self.c) 82 | y_input = y_input.permute(1, 0).contiguous() 83 | y_input = y_input.view( 84 | -1, int(math.sqrt(y_input.shape[-1])), int(math.sqrt(y_input.shape[-1])) 85 | ) 86 | 87 | y_input = voltage2pixel( 88 | y_input, self.phi[: self.c], real_data["min"], real_data["max"] 89 | ) 90 | return y_input 91 | 92 | def __len__(self): 93 | return len(self.real_data) 94 | 95 | 96 | def deploy(learner): 97 | """Real experimental data""" 98 | inference_config = load_config("../config/inference_config.yaml") 99 | sr = args.sampling_ratio 100 | directory = f"../inference_images/ReconNet/SPI/{int(sr * 100):04d}" 101 | os.makedirs(directory, exist_ok=True) 102 | real_dataset = RealDataset(sr, inference_config) 103 | for x in real_dataset: 104 | prediction = learner(x.unsqueeze(0)) 105 | prediction = prediction.squeeze().squeeze().cpu().detach().numpy() 106 | prediction = scipy.ndimage.zoom(prediction, 4, order=0, mode="nearest") 107 | cv2.imwrite(f"{directory}/{time.time()}.png", prediction * 255) 108 | print("Finished reconstructing SPI images.") 109 | 110 | 111 | if __name__ == "__main__": 112 | learner, trainer = setup() 113 | deploy(learner) 114 | -------------------------------------------------------------------------------- /src/benchmark/scsnet/inference_spi.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import time 4 | import warnings 5 | import numpy as np 6 | import cv2 7 | import argparse 8 | import scipy.ndimage 9 | import scipy.io 10 | import math 11 | import torch 12 | import pytorch_lightning as pl 13 | from .learner import SCSNetLearner 14 | from utils import voltage2pixel, load_config 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "-s", 20 | "--sampling_ratio", 21 | type=float, 22 | required=True, 23 | help="Sampling ratio in percentage", 24 | ) 25 | args = parser.parse_args() 26 | 27 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 28 | warnings.simplefilter("ignore") 29 | 30 | 31 | def setup(): 32 | inference_config = load_config("../config/inference_config.yaml") 33 | # ds = inference_config["dataset"] 34 | sr = inference_config["sampling_ratio"] 35 | checkpoint_folder = f"../logs/SCSNet_STL10_{int(sr * 100):04d}/best" 36 | 37 | if not os.path.exists(checkpoint_folder): 38 | run_name = os.listdir(Path(checkpoint_folder).parent)[0] 39 | checkpoint_path = ( 40 | f"{Path(checkpoint_folder).parent}/{run_name}/checkpoints/last.ckpt" 41 | ) 42 | print( 43 | f"The checkpoint from the run '{run_name}' is selected by default. " 44 | + "If this is not intended, change the name of the preferred checkpoint folder to 'best'."
45 | ) 46 | else: 47 | checkpoint_path = f"{checkpoint_folder}/checkpoints/last.ckpt" 48 | 49 | learner = SCSNetLearner.load_from_checkpoint(checkpoint_path=checkpoint_path) 50 | 51 | trainer = pl.Trainer( 52 | gpus=1 if inference_config["gpu"] else 0, 53 | logger=False, 54 | default_root_dir="../", 55 | ) 56 | return learner, trainer 57 | 58 | 59 | class RealDataset: 60 | def __init__(self, sampling_ratio, inference_config): 61 | self.real_data = inference_config["real_data"] 62 | self.phi = np.load(inference_config["measurement_matrix"]) 63 | self.c = int(sampling_ratio / 100 * 16) 64 | 65 | def __getitem__(self, idx): 66 | real_data = self.real_data[idx] 67 | path = os.path.join("../inference_input", real_data["filename"]) 68 | y_input = scipy.io.loadmat(path)["y"] 69 | 70 | y_input = y_input[ 71 | np.mod(np.arange(len(y_input)), len(y_input) // 64) < self.c 72 | ] # discard extra measurements 73 | 74 | y_input = torch.FloatTensor(y_input).permute(1, 0) 75 | y_input -= y_input.min() 76 | y_input /= real_data["max"] 77 | 78 | # Permute is necessary because during sampling, we used "channel-last" format. 79 | # Hence, we need to permute it to become channel-first to match PyTorch "channel-first" format 80 | y_input = y_input.view(-1, self.c) 81 | y_input = y_input.permute(1, 0).contiguous() 82 | y_input = y_input.view( 83 | -1, int(math.sqrt(y_input.shape[-1])), int(math.sqrt(y_input.shape[-1])) 84 | ) 85 | 86 | y_input = voltage2pixel( 87 | y_input, self.phi[: self.c], real_data["min"], real_data["max"] 88 | ) 89 | return y_input 90 | 91 | def __len__(self): 92 | return len(self.real_data) 93 | 94 | 95 | def predict_one(): 96 | # TODO: for standard dataset test set. 97 | return 98 | 99 | 100 | def deploy(learner): 101 | """Real experimental data""" 102 | inference_config = load_config("../config/inference_config.yaml") 103 | sr = args.sampling_ratio 104 | directory = f"../inference_images/SCSNet/SPI/{int(sr * 100):04d}" 105 | os.makedirs(directory, exist_ok=True) 106 | real_dataset = RealDataset(sr, inference_config) 107 | for x in real_dataset: 108 | prediction = learner(x.unsqueeze(0)) 109 | prediction = prediction.squeeze().squeeze().cpu().detach().numpy() 110 | prediction = scipy.ndimage.zoom(prediction, 4, order=0, mode="nearest") 111 | cv2.imwrite(f"{directory}/{time.time()}.png", prediction * 255) 112 | print("Finished reconstructing SPI images.") 113 | 114 | 115 | if __name__ == "__main__": 116 | learner, trainer = setup() 117 | deploy(learner) 118 | -------------------------------------------------------------------------------- /src/inference_spi.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import time 4 | import argparse 5 | import warnings 6 | import numpy as np 7 | import cv2 8 | import scipy.ndimage 9 | import scipy.io 10 | import math 11 | import torch 12 | import pytorch_lightning as pl 13 | from engine.learner import BCSUNetLearner 14 | from utils import voltage2pixel, load_config 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "-d", "--dataset", type=str, required=True, help="'EMNIST', 'SVHN', or 'STL10'" 20 | ) 21 | parser.add_argument( 22 | "-s", 23 | "--sampling_ratio", 24 | type=float, 25 | required=True, 26 | help="Sampling ratio in percentage", 27 | ) 28 | args = parser.parse_args() 29 | 30 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 31 | warnings.simplefilter("ignore") 32 | 33 | 34 | def setup(): 35 | inference_config = 
load_config("../config/inference_config.yaml") 36 | ds = args.dataset 37 | sr = args.sampling_ratio 38 | 39 | checkpoint_folder = f"../logs/BCSUNet_{ds}_{int(sr * 100):04d}/best" 40 | 41 | if not os.path.exists(checkpoint_folder): 42 | run_name = os.listdir(Path(checkpoint_folder).parent)[0] 43 | checkpoint_path = ( 44 | f"{Path(checkpoint_folder).parent}/{run_name}/checkpoints/best.ckpt" 45 | ) 46 | message = ( 47 | f"The checkpoint from the run '{run_name}' is selected by default." 48 | + "If this is not intended, change the name of the preferred checkpoint folder to 'best'." 49 | ) 50 | print(message) 51 | else: 52 | checkpoint_path = f"{checkpoint_folder}/checkpoints/best.ckpt" 53 | 54 | learner = BCSUNetLearner.load_from_checkpoint(checkpoint_path=checkpoint_path) 55 | 56 | trainer = pl.Trainer( 57 | gpus=1 if inference_config["gpu"] else 0, 58 | logger=False, 59 | default_root_dir="../", 60 | ) 61 | return learner, trainer 62 | 63 | 64 | class RealDataset: 65 | def __init__(self, sampling_ratio, inference_config): 66 | self.real_data = inference_config["real_data"] 67 | self.phi = np.load(inference_config["measurement_matrix"]) 68 | self.c = int(sampling_ratio / 100 * 16) 69 | 70 | def __getitem__(self, idx): 71 | real_data = self.real_data[idx] 72 | path = os.path.join("../inference_input", real_data["filename"]) 73 | y_input = scipy.io.loadmat(path)["y"] 74 | 75 | y_input = y_input[ 76 | np.mod(np.arange(len(y_input)), len(y_input) // 64) < self.c 77 | ] # discard extra measurements 78 | 79 | y_input = torch.FloatTensor(y_input).permute(1, 0) 80 | y_input -= y_input.min() 81 | y_input /= real_data["max"] 82 | 83 | # Permute is necessary because during sampling, we used "channel-last" format. 84 | # Hence, we need to permute it to become channel-first to match PyTorch "channel-first" format 85 | y_input = y_input.view(-1, self.c) 86 | y_input = y_input.permute(1, 0).contiguous() 87 | y_input = y_input.view( 88 | -1, int(math.sqrt(y_input.shape[-1])), int(math.sqrt(y_input.shape[-1])) 89 | ) 90 | 91 | y_input = voltage2pixel( 92 | y_input, self.phi[: self.c], real_data["min"], real_data["max"] 93 | ) 94 | return y_input 95 | 96 | def __len__(self): 97 | return len(self.real_data) 98 | 99 | 100 | def deploy(learner): 101 | """Real experimental data""" 102 | inference_config = load_config("../config/inference_config.yaml") 103 | sr = args.sampling_ratio 104 | real_dataset = RealDataset(sr, inference_config) 105 | save_dir = f"../inference_images/BCSUNet/SPI/{int(sr * 100):04d}_{args.dataset}" 106 | os.makedirs(save_dir, exist_ok=True) 107 | 108 | for x in real_dataset: 109 | prediction = learner(x.unsqueeze(0)) 110 | prediction = prediction.squeeze().squeeze().cpu().detach().numpy() 111 | prediction = scipy.ndimage.zoom(prediction, 8, order=0, mode="nearest") 112 | cv2.imwrite(f"{save_dir}/{time.time()}.png", prediction * 255) 113 | 114 | print("Finished reconstructing SPI images.") 115 | 116 | 117 | if __name__ == "__main__": 118 | learner, trainer = setup() 119 | deploy(learner) 120 | -------------------------------------------------------------------------------- /src/model/spectral_norm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of spectral normalization for GANs. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class SpectralNorm: 10 | r""" 11 | Spectral Normalization for GANs (Miyato 2018). 
12 | 13 | Inheritable class for performing spectral normalization of weights, 14 | as approximated using power iteration. 15 | 16 | Details: See Algorithm 1 of Appendix A (Miyato 2018). 17 | 18 | Attributes: 19 | n_dim (int): Number of dimensions. 20 | num_iters (int): Number of iterations for power iter. 21 | eps (float): Epsilon for zero division tolerance when normalizing. 22 | """ 23 | 24 | def __init__(self, n_dim, num_iters=1, eps=1e-12): 25 | self.num_iters = num_iters 26 | self.eps = eps 27 | 28 | # Register a singular vector for each sigma 29 | self.register_buffer("sn_u", torch.randn(1, n_dim)) 30 | self.register_buffer("sn_sigma", torch.ones(1)) 31 | 32 | @property 33 | def u(self): 34 | return getattr(self, "sn_u") 35 | 36 | @property 37 | def sigma(self): 38 | return getattr(self, "sn_sigma") 39 | 40 | def _power_iteration(self, W, u, num_iters, eps=1e-12): 41 | with torch.no_grad(): 42 | for _ in range(num_iters): 43 | v = F.normalize(torch.matmul(u, W), eps=eps) 44 | u = F.normalize(torch.matmul(v, W.t()), eps=eps) 45 | 46 | # Note: must have gradients, otherwise weights do not get updated! 47 | sigma = torch.mm(u, torch.mm(W, v.t())) 48 | 49 | return sigma, u, v 50 | 51 | def sn_weights(self): 52 | r""" 53 | Spectrally normalize current weights of the layer. 54 | """ 55 | W = self.weight.view(self.weight.shape[0], -1) 56 | 57 | # Power iteration 58 | sigma, u, v = self._power_iteration( 59 | W=W, u=self.u, num_iters=self.num_iters, eps=self.eps 60 | ) 61 | 62 | # Update only during training 63 | if self.training: 64 | with torch.no_grad(): 65 | self.sigma[:] = sigma 66 | self.u[:] = u 67 | 68 | return self.weight / sigma 69 | 70 | 71 | class SNConv2d(nn.Conv2d, SpectralNorm): 72 | r""" 73 | Spectrally normalized layer for Conv2d. 74 | 75 | Attributes: 76 | in_channels (int): Input channel dimension. 77 | out_channels (int): Output channel dimensions. 78 | """ 79 | 80 | def __init__(self, in_channels, out_channels, *args, **kwargs): 81 | nn.Conv2d.__init__(self, in_channels, out_channels, *args, **kwargs) 82 | 83 | SpectralNorm.__init__( 84 | self, n_dim=out_channels, num_iters=kwargs.get("num_iters", 1) 85 | ) 86 | 87 | def forward(self, x): 88 | return F.conv2d( 89 | input=x, 90 | weight=self.sn_weights(), 91 | bias=self.bias, 92 | stride=self.stride, 93 | padding=self.padding, 94 | dilation=self.dilation, 95 | groups=self.groups, 96 | ) 97 | 98 | 99 | class SNLinear(nn.Linear, SpectralNorm): 100 | r""" 101 | Spectrally normalized layer for Linear. 102 | 103 | Attributes: 104 | in_features (int): Input feature dimensions. 105 | out_features (int): Output feature dimensions. 106 | """ 107 | 108 | def __init__(self, in_features, out_features, *args, **kwargs): 109 | nn.Linear.__init__(self, in_features, out_features, *args, **kwargs) 110 | 111 | SpectralNorm.__init__( 112 | self, n_dim=out_features, num_iters=kwargs.get("num_iters", 1) 113 | ) 114 | 115 | def forward(self, x): 116 | return F.linear(input=x, weight=self.sn_weights(), bias=self.bias) 117 | 118 | 119 | class SNEmbedding(nn.Embedding, SpectralNorm): 120 | r""" 121 | Spectrally normalized layer for Embedding. 122 | 123 | Attributes: 124 | num_embeddings (int): Number of embeddings. 
125 | embedding_dim (int): Dimensions of each embedding vector. 126 | """ 127 | 128 | def __init__(self, num_embeddings, embedding_dim, *args, **kwargs): 129 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, *args, **kwargs) 130 | 131 | SpectralNorm.__init__(self, n_dim=num_embeddings) 132 | 133 | def forward(self, x): 134 | return F.embedding(input=x, weight=self.sn_weights()) 135 | -------------------------------------------------------------------------------- /src/model/upsamplenet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import functional as F 3 | from model.layers import SNConv2d 4 | 5 | 6 | class ReshapeNet(nn.Module): 7 | """The "initial reconstruction network" of SCSNet""" 8 | 9 | def __init__(self, in_channels, block_size=4): 10 | super().__init__() 11 | self.block_size = block_size 12 | self.conv = nn.Conv2d(in_channels, block_size ** 2, kernel_size=1) 13 | 14 | def forward(self, x): 15 | x = self.conv(x) 16 | out = self._permute(x) 17 | return out 18 | 19 | def _permute(self, x): 20 | B, C, H, W = x.shape 21 | x = x.permute(0, 2, 3, 1) 22 | x = x.view(B, H, W, self.block_size, self.block_size) 23 | x = x.permute(0, 1, 3, 2, 4).contiguous() 24 | out = x.view(-1, 1, H * self.block_size, W * self.block_size) 25 | return out 26 | 27 | 28 | class UpsampleNet(nn.Module): 29 | def __init__(self, sampling_ratio, upsamplenet_config): 30 | super().__init__() 31 | kernel_size = 4 32 | first_out_channels = int(sampling_ratio * kernel_size ** 2) 33 | config = upsamplenet_config 34 | 35 | self.up1 = UpResBlock( 36 | in_channels=first_out_channels, 37 | out_channels=config["out_channels_1"], 38 | middle_channels=None, 39 | upsample=True, 40 | use_transpose_conv=config["use_transpose_conv"], 41 | spectral_norm=config["spectral_norm"], 42 | ) 43 | 44 | self.up2 = UpResBlock( 45 | in_channels=config["out_channels_1"], 46 | out_channels=config["out_channels_2"], 47 | middle_channels=None, 48 | upsample=True, 49 | use_transpose_conv=config["use_transpose_conv"], 50 | spectral_norm=config["spectral_norm"], 51 | ) 52 | 53 | def forward(self, x): 54 | x = self.up1(x) 55 | out = self.up2(x) # passed to UNet 56 | return out 57 | 58 | 59 | class UpResBlock(nn.Module): 60 | def __init__( 61 | self, 62 | in_channels, 63 | out_channels, 64 | middle_channels=None, 65 | upsample=True, 66 | use_transpose_conv=False, 67 | norm_type="instance", 68 | spectral_norm=True, 69 | init_type="xavier", 70 | ): 71 | 72 | super().__init__() 73 | self.upsample = upsample 74 | self.use_transpose_conv = use_transpose_conv 75 | 76 | if middle_channels is None: 77 | middle_channels = out_channels 78 | 79 | if use_transpose_conv: 80 | assert upsample is True 81 | self.conv1 = nn.ConvTranspose2d( 82 | in_channels, 83 | middle_channels, 84 | kernel_size=2, 85 | stride=2, 86 | padding=0, # kernel_size=2 with stride=2 and no padding doubles H and W exactly 87 | bias=False, 88 | ) 89 | self.conv2 = nn.ConvTranspose2d( 90 | middle_channels, 91 | out_channels, 92 | kernel_size=2, 93 | stride=2, 94 | padding=0, 95 | bias=False, 96 | ) 97 | 98 | else: # if transpose conv is not used.
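# An upsample-then-conv path (bilinear interpolation followed by a 3x3 conv) is used instead; this is commonly preferred over transpose convolutions because it avoids checkerboard artifacts (Odena et al., 2016).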
99 | # The `_residual_block` method decides whether to upsample, depending on the `upsample` flag. 100 | conv = SNConv2d if spectral_norm else nn.Conv2d 101 | self.conv1 = conv( 102 | in_channels, 103 | middle_channels, 104 | kernel_size=3, 105 | stride=1, 106 | padding=1, 107 | bias=False, 108 | ) 109 | self.conv2 = conv( 110 | middle_channels, 111 | out_channels, 112 | kernel_size=3, 113 | stride=1, 114 | padding=1, 115 | bias=False, 116 | ) 117 | 118 | self.bn1 = nn.BatchNorm2d(middle_channels) 119 | self.bn2 = nn.BatchNorm2d(out_channels) 120 | 121 | self.relu = nn.ReLU(inplace=True) 122 | 123 | def _upsample(self, x, conv_layer): 124 | if self.use_transpose_conv: 125 | return conv_layer(x) 126 | else: 127 | return conv_layer( 128 | F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False) 129 | ) 130 | 131 | def _residual_block(self, x): 132 | x = self._upsample(x, self.conv1) if self.upsample else self.conv1(x) 133 | x = self.bn1(x) 134 | x = self.relu(x) 135 | x = self.conv2(x) 136 | x = self.bn2(x) 137 | out = self.relu(x) 138 | return out 139 | 140 | def _shortcut(self, x): 141 | return self._upsample(x, self.conv1) if self.upsample else self.conv1(x) 142 | 143 | def forward(self, x): 144 | return self._residual_block(x) + self._shortcut(x) 145 | -------------------------------------------------------------------------------- /src/data/stl10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data.sampler import SubsetRandomSampler 6 | import torchvision 7 | from torchvision import transforms 8 | import pytorch_lightning as pl 9 | from .base_dataset import BaseDataset 10 | from utils import create_patches 11 | 12 | 13 | class STL10Dataset(BaseDataset): 14 | def __init__(self, sampling_ratio: float, bcs: bool, tfms=None, train=True): 15 | super().__init__(sampling_ratio=sampling_ratio, bcs=bcs) 16 | self.data = torchvision.datasets.STL10( 17 | "../input/STL10", 18 | split="unlabeled" if train else "test", 19 | transform=tfms, 20 | download=True, 21 | ) 22 | 23 | def __getitem__(self, idx): 24 | image, _ = self.data[idx] 25 | image_ = image.unsqueeze(dim=1) 26 | y = self.cs_operator(image_) 27 | return y.squeeze(dim=0), image 28 | 29 | def __len__(self): 30 | return len(self.data) 31 | 32 | 33 | class STL10ReconnetTestDataset(BaseDataset): 34 | """ 35 | The test dataset for ReconNet is unique because we need to reconstruct image patches and combine them. 36 | This is because ReconNet can only be used on small images. 37 | To reproduce this, we crop the 96x96 STL10 images into 32x32 patches.
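Each 32x32 patch is reconstructed independently; downstream inference code then tiles the nine patches of each image back into a 96x96 result (patch j goes to row j // 3, column j % 3 of the 3x3 grid).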
38 | """ 39 | 40 | def __init__(self, sampling_ratio: float, bcs: bool): 41 | super().__init__(sampling_ratio=sampling_ratio, bcs=bcs) 42 | self.root_dir = "../input/STL10" 43 | self.data_dir = os.path.join(self.root_dir, "test_images_32x32") 44 | if not os.path.exists(self.data_dir): 45 | create_patches(self.root_dir, size=32) 46 | self.filenames = os.listdir(self.data_dir) 47 | 48 | def __getitem__(self, idx): 49 | image = cv2.imread( 50 | os.path.join(self.data_dir, self.filenames[idx]), 51 | cv2.IMREAD_GRAYSCALE, 52 | ) 53 | image = torch.tensor(image) / 255.0 54 | image = image.unsqueeze(dim=0) 55 | image_ = image.unsqueeze(dim=1) 56 | y = self.cs_operator(image_) 57 | return y.squeeze(dim=0), image 58 | 59 | def __len__(self): 60 | return len(self.filenames) 61 | 62 | 63 | class STL10DataModule(pl.LightningDataModule): 64 | def __init__(self, config, reconnet=False): 65 | super().__init__() 66 | self.config = config 67 | self.dm_config = config["data_module"] 68 | self.reconnet = reconnet # whether or not ReconNet is the architecture. 69 | 70 | def setup(self, stage=None): 71 | train_tfms_list = [transforms.RandomCrop(32)] if self.reconnet else [] 72 | train_tfms_list += [ 73 | transforms.RandomHorizontalFlip(p=0.5), 74 | transforms.ColorJitter(brightness=0.2, contrast=0.2), 75 | transforms.Grayscale(), 76 | transforms.ToTensor(), 77 | ] 78 | train_tfms = transforms.Compose(train_tfms_list) 79 | 80 | val_tfms_list = [transforms.CenterCrop(32)] if self.reconnet else [] 81 | val_tfms_list += [ 82 | transforms.Grayscale(), 83 | transforms.ToTensor(), 84 | ] 85 | val_tfms = transforms.Compose(val_tfms_list) 86 | 87 | self.train_dataset = STL10Dataset( 88 | sampling_ratio=self.config["sampling_ratio"], 89 | bcs=self.config["bcs"], 90 | tfms=train_tfms, 91 | train=True, 92 | ) 93 | 94 | self.val_dataset = STL10Dataset( 95 | sampling_ratio=self.config["sampling_ratio"], 96 | bcs=self.config["bcs"], 97 | tfms=val_tfms, 98 | train=True, 99 | ) 100 | 101 | if self.reconnet: 102 | self.test_dataset = STL10ReconnetTestDataset( 103 | sampling_ratio=self.config["sampling_ratio"], 104 | bcs=self.config["bcs"], 105 | ) 106 | else: 107 | self.test_dataset = STL10Dataset( 108 | sampling_ratio=self.config["sampling_ratio"], 109 | bcs=self.config["bcs"], 110 | tfms=val_tfms, 111 | train=False, 112 | ) 113 | 114 | dataset_size = len(self.train_dataset) 115 | indices = torch.randperm(dataset_size) 116 | split = int(self.dm_config["val_percent"] * dataset_size) 117 | self.train_idx, self.valid_idx = indices[split:], indices[:split] 118 | 119 | def train_dataloader(self): 120 | train_sampler = SubsetRandomSampler(self.train_idx) 121 | return DataLoader( 122 | self.train_dataset, 123 | batch_size=self.dm_config["batch_size"], 124 | sampler=train_sampler, 125 | num_workers=self.dm_config["num_workers"], 126 | ) 127 | 128 | def val_dataloader(self): 129 | val_sampler = SubsetRandomSampler(self.valid_idx) 130 | return DataLoader( 131 | self.val_dataset, 132 | batch_size=self.dm_config["batch_size"], 133 | sampler=val_sampler, 134 | num_workers=self.dm_config["num_workers"], 135 | ) 136 | 137 | def test_dataloader(self): 138 | return DataLoader( 139 | self.test_dataset, 140 | batch_size=self.dm_config["batch_size"], 141 | num_workers=self.dm_config["num_workers"], 142 | ) 143 | 144 | def predict_dataloader(self): 145 | if self.reconnet: 146 | batch_size = 9 * 12 # 9 image patches per 96x96 image 147 | else: 148 | batch_size = self.dm_config["batch_size"] 149 | 150 | return DataLoader( 151 | self.test_dataset, 
152 | batch_size=batch_size, 153 | num_workers=self.dm_config["num_workers"], 154 | ) 155 | -------------------------------------------------------------------------------- /src/model/layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for building specific layers needed by GAN architecture: 3 | https://github.com/kwotsin/mimicry/tree/master/torch_mimicry/modules/layers.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from model import spectral_norm 10 | 11 | 12 | class SelfAttention(nn.Module): 13 | """ 14 | Self-attention layer based on version used in BigGAN code: 15 | https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py 16 | """ 17 | 18 | def __init__(self, num_feat, spectral_norm=True): 19 | super().__init__() 20 | self.num_feat = num_feat 21 | self.spectral_norm = spectral_norm 22 | 23 | if self.spectral_norm: 24 | self.theta = SNConv2d( 25 | self.num_feat, self.num_feat >> 3, 1, 1, padding=0, bias=False 26 | ) 27 | self.phi = SNConv2d( 28 | self.num_feat, self.num_feat >> 3, 1, 1, padding=0, bias=False 29 | ) 30 | self.g = SNConv2d( 31 | self.num_feat, self.num_feat >> 1, 1, 1, padding=0, bias=False 32 | ) 33 | self.o = SNConv2d( 34 | self.num_feat >> 1, self.num_feat, 1, 1, padding=0, bias=False 35 | ) 36 | 37 | else: 38 | self.theta = nn.Conv2d( 39 | self.num_feat, self.num_feat >> 3, 1, 1, padding=0, bias=False 40 | ) 41 | self.phi = nn.Conv2d( 42 | self.num_feat, self.num_feat >> 3, 1, 1, padding=0, bias=False 43 | ) 44 | self.g = nn.Conv2d( 45 | self.num_feat, self.num_feat >> 1, 1, 1, padding=0, bias=False 46 | ) 47 | self.o = nn.Conv2d( 48 | self.num_feat >> 1, self.num_feat, 1, 1, padding=0, bias=False 49 | ) 50 | 51 | self.gamma = nn.Parameter(torch.tensor(0.0), requires_grad=True) 52 | 53 | def forward(self, x): 54 | """ 55 | Feedforward function. Implementation differs from actual SAGAN paper, 56 | see note from BigGAN: 57 | https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py#L142 58 | 59 | See official TF Implementation: 60 | https://github.com/brain-research/self-attention-gan/blob/master/non_local.py 61 | 62 | Args: 63 | x (Tensor): Input feature map. 64 | 65 | Returns: 66 | Tensor: Feature map weighed with attention map. 67 | """ 68 | N, C, H, W = x.shape 69 | location_num = H * W 70 | downsampled_num = location_num >> 2 71 | 72 | # Theta path 73 | theta = self.theta(x) 74 | theta = theta.view(N, C >> 3, location_num) # (N, C>>3, H*W) 75 | 76 | # Phi path 77 | phi = self.phi(x) 78 | phi = F.max_pool2d(phi, [2, 2], stride=2) 79 | phi = phi.view(N, C >> 3, downsampled_num) # (N, C>>3, H*W>>2) 80 | 81 | # Attention map 82 | attn = torch.bmm(theta.transpose(1, 2), phi) 83 | attn = F.softmax(attn, -1) # (N, H*W, H*W>>2) 84 | # print(torch.sum(attn, axis=2)) # (N, H*W) 85 | 86 | # Conv value 87 | g = self.g(x) 88 | g = F.max_pool2d(g, [2, 2], stride=2) 89 | g = g.view(N, C >> 1, downsampled_num) # (N, C>>1, H*W>>2) 90 | 91 | # Apply attention 92 | attn_g = torch.bmm(g, attn.transpose(1, 2)) # (N, C>>1, H*W) 93 | attn_g = attn_g.view(N, C >> 1, H, W) # (N, C>>1, H, W) 94 | 95 | # Project back feature size 96 | attn_g = self.o(attn_g) 97 | 98 | # Weigh attention map 99 | output = x + self.gamma * attn_g 100 | 101 | return output 102 | 103 | 104 | def SNConv2d(*args, default=True, **kwargs): 105 | r""" 106 | Wrapper for applying spectral norm on conv2d layer. 
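With default=True, this delegates to torch.nn.utils.spectral_norm, which re-parametrizes the weight via a pre-forward hook; with default=False, it falls back to the in-house implementation in model/spectral_norm.py.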
107 | """ 108 | if default: 109 | return nn.utils.spectral_norm(nn.Conv2d(*args, **kwargs)) 110 | 111 | else: 112 | return spectral_norm.SNConv2d(*args, **kwargs) 113 | 114 | 115 | def SNLinear(*args, default=True, **kwargs): 116 | r""" 117 | Wrapper for applying spectral norm on linear layer. 118 | """ 119 | if default: 120 | return nn.utils.spectral_norm(nn.Linear(*args, **kwargs)) 121 | 122 | else: 123 | return spectral_norm.SNLinear(*args, **kwargs) 124 | 125 | 126 | def SNEmbedding(*args, default=True, **kwargs): 127 | r""" 128 | Wrapper for applying spectral norm on embedding layer. 129 | """ 130 | if default: 131 | return nn.utils.spectral_norm(nn.Embedding(*args, **kwargs)) 132 | 133 | else: 134 | return spectral_norm.SNEmbedding(*args, **kwargs) 135 | 136 | 137 | class ConditionalBatchNorm2d(nn.Module): 138 | r""" 139 | Conditional Batch Norm as implemented in 140 | https://github.com/pytorch/pytorch/issues/8985 141 | 142 | Attributes: 143 | num_features (int): Size of feature map for batch norm. 144 | num_classes (int): Determines size of embedding layer to condition BN. 145 | """ 146 | 147 | def __init__(self, num_features, num_classes): 148 | super().__init__() 149 | self.num_features = num_features 150 | self.bn = nn.BatchNorm2d(num_features, affine=False) 151 | self.embed = nn.Embedding(num_classes, num_features * 2) 152 | self.embed.weight.data[:, :num_features].normal_( 153 | 1, 0.02 154 | ) # Initialise scale at N(1, 0.02) 155 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 156 | 157 | def forward(self, x, y): 158 | r""" 159 | Feedforwards for conditional batch norm. 160 | 161 | Args: 162 | x (Tensor): Input feature map. 163 | y (Tensor): Input class labels for embedding. 164 | 165 | Returns: 166 | Tensor: Output feature map. 167 | """ 168 | out = self.bn(x) 169 | gamma, beta = self.embed(y).chunk( 170 | 2, 1 171 | ) # divide into 2 chunks, split from dim 1. 172 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view( 173 | -1, self.num_features, 1, 1 174 | ) 175 | 176 | return out 177 | -------------------------------------------------------------------------------- /src/model/unet.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class UNet(nn.Module): 7 | def __init__( 8 | self, 9 | config, 10 | input_nc=8, 11 | output_nc=1, 12 | num_downs=5, 13 | ): 14 | """ 15 | Construct a Unet generator 16 | Parameters: 17 | input_nc (int) -- the number of channels in input images 18 | output_nc (int) -- the number of channels in output images 19 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, 20 | image of size 128x128 will become of size 1x1 # at the bottleneck 21 | channels (int) -- the number of filters in the last conv layer 22 | We construct the U-Net from the innermost layer to the outermost layer. 23 | It is a recursive process. 
24 | """ 25 | super().__init__() 26 | channels = config["channels"] 27 | unet_block = UnetSkipConnectionBlock( 28 | channels * 8, 29 | channels * 8, 30 | input_nc=None, 31 | submodule=None, 32 | norm_layer=nn.BatchNorm2d, 33 | innermost=True, 34 | ) # add the innermost layer first 35 | 36 | # Add intermediate layers with ngf * 8 filters 37 | for i in range(num_downs - 5): 38 | unet_block = UnetSkipConnectionBlock( 39 | channels * 8, 40 | channels * 8, 41 | input_nc=None, 42 | submodule=unet_block, 43 | norm_layer=nn.BatchNorm2d, 44 | use_dropout=config["use_dropout"], 45 | ) 46 | 47 | # Gradually reduce the number of filters from `channels*8` to `channels` 48 | unet_block = UnetSkipConnectionBlock( 49 | channels * 4, 50 | channels * 8, 51 | input_nc=None, 52 | submodule=unet_block, 53 | norm_layer=nn.BatchNorm2d, 54 | ) 55 | unet_block = UnetSkipConnectionBlock( 56 | channels * 2, 57 | channels * 4, 58 | input_nc=None, 59 | submodule=unet_block, 60 | norm_layer=nn.BatchNorm2d, 61 | ) 62 | unet_block = UnetSkipConnectionBlock( 63 | channels, 64 | channels * 2, 65 | input_nc=None, 66 | submodule=unet_block, 67 | norm_layer=nn.BatchNorm2d, 68 | ) 69 | 70 | self.model = UnetSkipConnectionBlock( 71 | output_nc, 72 | channels, 73 | input_nc=input_nc, 74 | submodule=unet_block, 75 | outermost=True, 76 | norm_layer=nn.BatchNorm2d, 77 | ) # add the outermost layer 78 | 79 | def forward(self, input): 80 | return self.model(input) 81 | 82 | 83 | class UnetSkipConnectionBlock(nn.Module): 84 | """ 85 | Defines the Unet submodule with skip connection. 86 | X -------------------identity---------------------- 87 | |-- downsampling -- |submodule| -- upsampling --| 88 | 89 | """ 90 | 91 | def __init__( 92 | self, 93 | outer_nc, 94 | inner_nc, 95 | input_nc=None, 96 | submodule=None, 97 | outermost=False, 98 | innermost=False, 99 | norm_layer=nn.BatchNorm2d, 100 | use_dropout=False, 101 | ): 102 | """ 103 | Construct a Unet submodule with skip connections. 104 | 105 | Parameters: 106 | outer_nc (int) -- the number of filters in the outer conv layer 107 | inner_nc (int) -- the number of filters in the inner conv layer 108 | input_nc (int) -- the number of channels in input images/features 109 | submodule (UnetSkipConnectionBlock) -- previously defined submodules 110 | outermost (bool) -- if this module is the outermost module 111 | innermost (bool) -- if this module is the innermost module 112 | norm_layer -- normalization layer 113 | use_dropout (bool) -- if use dropout layers. 
114 | """ 115 | 116 | # TODO: add spectral norm 117 | 118 | super().__init__() 119 | self.outermost = outermost 120 | use_bias = ( 121 | (norm_layer.func == nn.InstanceNorm2d) 122 | if type(norm_layer) == partial 123 | else (norm_layer == nn.InstanceNorm2d) 124 | ) # use_bias is False if batch norm is used 125 | if input_nc is None: 126 | input_nc = outer_nc 127 | downconv = nn.Conv2d( 128 | input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias 129 | ) 130 | downrelu = nn.LeakyReLU(0.2, inplace=True) 131 | downnorm = norm_layer(inner_nc) 132 | uprelu = nn.ReLU(inplace=True) 133 | upnorm = norm_layer(outer_nc) 134 | 135 | if outermost: 136 | upconv = nn.ConvTranspose2d( 137 | inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=True 138 | ) 139 | down = [downconv] 140 | up = [uprelu, upconv, nn.Sigmoid()] 141 | model = down + [submodule] + up 142 | 143 | elif innermost: 144 | upconv = nn.ConvTranspose2d( 145 | inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias 146 | ) 147 | down = [downrelu, downconv] 148 | up = [uprelu, upconv, upnorm] 149 | model = down + up 150 | 151 | else: # middle layers 152 | upconv = nn.ConvTranspose2d( 153 | inner_nc * 2, 154 | outer_nc, 155 | kernel_size=4, 156 | stride=2, 157 | padding=1, 158 | bias=use_bias, 159 | ) 160 | down = [downrelu, downconv, downnorm] 161 | up = [uprelu, upconv, upnorm] 162 | model = down + [submodule] + up 163 | if use_dropout: 164 | model = model + [nn.Dropout(0.5)] 165 | 166 | self.model = nn.Sequential(*model) 167 | 168 | def forward(self, x): 169 | if self.outermost: 170 | return self.model(x) 171 | else: # add skip connections 172 | return torch.cat([x, self.model(x)], dim=1) 173 | --------------------------------------------------------------------------------