├── .gitignore ├── LICENSE ├── README.md ├── assets └── CAE.png ├── codebase ├── __init__.py ├── config │ ├── config.yaml │ └── experiment │ │ ├── CAE_2Shapes.yaml │ │ ├── CAE_3Shapes.yaml │ │ └── CAE_MNISTShapes.yaml ├── main.py ├── model │ ├── ComplexAutoEncoder.py │ ├── ComplexDecoder.py │ ├── ComplexEncoder.py │ ├── ComplexLayers.py │ ├── __init__.py │ └── model_utils.py └── utils │ ├── __init__.py │ ├── data_utils.py │ ├── eval_utils.py │ └── utils.py ├── environment.yml └── setup.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Sindy Löwe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Complex-Valued AutoEncoders for Object Discovery 2 | 3 | We present the Complex AutoEncoder – an object discovery approach that takes inspiration from neuroscience to implement distributed object-centric representations. 4 | After introducing complex-valued activations into a convolutional autoencoder, it learns to encode feature information in the activations’ magnitudes and object affiliation in their phase values. 5 | 6 | This repo provides a reference implementation for the Complex AutoEncoder (CAE) as introduced in our paper "Complex-Valued AutoEncoders for Object Discovery" ([https://arxiv.org/abs/2204.02075](https://arxiv.org/abs/2204.02075)) by Sindy Löwe, Phillip Lippe, Maja Rudolph and Max Welling. 7 | 8 | Model figure 9 | 10 | 11 | ## Setup 12 | 13 | Make sure you have conda installed (you can find instructions [here](https://www.anaconda.com/products/distribution)). 14 | 15 | Then, run ```bash setup.sh``` to download the 2Shapes, 3Shapes and MNIST&Shapes datasets and to create a conda environment with all required packages. 16 | 17 | The script installs PyTorch with CUDA 11.3, which is the version we used for our experiments. If you want to use a different version, you can change the version number in the ```setup.sh``` script. 
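If your machine runs a different CUDA version, the line to edit is the PyTorch install command in ```setup.sh```. For illustration only (check that a matching PyTorch 1.11 build exists for your CUDA version):

```
conda install pytorch=1.11 torchvision torchaudio cudatoolkit=11.5 -c pytorch
```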
18 | 19 | 20 | ## Run Experiments 21 | 22 | To train and test the CAE, run one of the following commands, depending on the dataset you want to use: 23 | 24 | ```python -m codebase.main +experiment=CAE_2Shapes``` 25 | 26 | ```python -m codebase.main +experiment=CAE_3Shapes``` 27 | 28 | ```python -m codebase.main +experiment=CAE_MNISTShapes``` 29 | 30 | 31 | ## Citation 32 | When using this code, please cite our paper: 33 | 34 | ``` 35 | @article{lowe2022complex, 36 | title={Complex-Valued Autoencoders for Object Discovery}, 37 | author={L{\"o}we, Sindy and Lippe, Phillip and Rudolph, Maja and Welling, Max}, 38 | journal={Transactions on Machine Learning Research (TMLR)}, 39 | year={2022} 40 | } 41 | ``` 42 | 43 | ## Contact 44 | For questions and suggestions, feel free to open an issue on GitHub or send an email to [loewe.sindy@gmail.com](mailto:loewe.sindy@gmail.com). 45 | -------------------------------------------------------------------------------- /assets/CAE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loeweX/ComplexAutoEncoder/3328961c64b4a0a63db4139ef11a8dd29a3335eb/assets/CAE.png -------------------------------------------------------------------------------- /codebase/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loeweX/ComplexAutoEncoder/3328961c64b4a0a63db4139ef11a8dd29a3335eb/codebase/__init__.py -------------------------------------------------------------------------------- /codebase/config/config.yaml: -------------------------------------------------------------------------------- 1 | use_cuda: True 2 | seed: 42 3 | 4 | defaults: 5 | - _self_ # Override values within this file with values in selected files. 6 | 7 | 8 | input: 9 | load_path: datasets 10 | batch_size: 64 11 | channel: 1 12 | image_height: 32 13 | image_width: 32 14 | 15 | 16 | model: 17 | hidden_dim: 32 # Only used for convolutional layers. 18 | linear_dim: 64 19 | 20 | 21 | training: 22 | learning_rate: 1e-3 23 | learning_rate_schedule: 1 # 0 - constant lr; 1 - warm-up 24 | warmup_steps: 500 25 | 26 | print_idx: 5000 27 | val_idx: -1 # -1 - no validation; X - every X steps. 28 | 29 | 30 | evaluation: 31 | phase_mask_threshold: 0.1 # Threshold on minimum magnitude to use when evaluating phases; -1: no masking. 32 | 33 | 34 | hydra: 35 | run: 36 | dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S.%f} 37 | 38 | -------------------------------------------------------------------------------- /codebase/config/experiment/CAE_2Shapes.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | input: 4 | file_name: 2shapes 5 | 6 | training: 7 | steps: 10000 # How many times do we train on a single batch? 8 | 9 | -------------------------------------------------------------------------------- /codebase/config/experiment/CAE_3Shapes.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | input: 4 | file_name: 3shapes 5 | 6 | training: 7 | steps: 100000 # How many times do we train on a single batch? 
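# Note: this experiment trains for 10x more steps than the 2Shapes and MNIST&Shapes configs (100000 vs. 10000).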
8 | 9 | -------------------------------------------------------------------------------- /codebase/config/experiment/CAE_MNISTShapes.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | input: 4 | file_name: MNIST_shapes 5 | 6 | training: 7 | steps: 10000 # How many times do we train on a single batch? 8 | 9 | -------------------------------------------------------------------------------- /codebase/main.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import defaultdict 3 | 4 | import hydra 5 | import torch 6 | from omegaconf import DictConfig 7 | from tqdm import tqdm 8 | from datetime import timedelta 9 | 10 | from codebase.model import model_utils 11 | from codebase.utils import data_utils, utils 12 | 13 | 14 | def train(opt, model, optimizer, train_iterator, train_loader): 15 | start_time = time.time() 16 | 17 | for step in range(opt.training.steps + 1): 18 | input_images, labels = data_utils.get_input(opt, train_iterator, train_loader) 19 | 20 | optimizer, lr = utils.update_learning_rate(optimizer, opt, step) 21 | 22 | optimizer.zero_grad() 23 | outputs = model(input_images, labels, step) 24 | outputs["loss"].backward() 25 | optimizer.step() 26 | 27 | if step % opt.training.val_idx == 0 and opt.training.val_idx != -1: 28 | validate_or_test(opt, step, model, "val") 29 | 30 | total_train_time = time.time() - start_time 31 | print(f"Total training time: {timedelta(seconds=total_train_time)}") 32 | 33 | model_utils.save_model(opt, model, optimizer) 34 | return optimizer, step 35 | 36 | 37 | def validate_or_test(opt, step, model, partition): 38 | data_loader, data_iterator = data_utils.get_data(opt, partition) 39 | 40 | test_results = defaultdict(float) 41 | 42 | model.eval() 43 | print(partition) 44 | 45 | test_time = time.time() 46 | 47 | with torch.no_grad(): 48 | for _ in tqdm(range(len(data_loader))): 49 | input_images, labels = data_utils.get_input(opt, data_iterator, data_loader) 50 | 51 | outputs = model( 52 | input_images, labels, step, partition=partition, evaluate=True 53 | ) 54 | 55 | test_results["loss"] += outputs["loss"] / len(data_loader) 56 | test_results["ARI+BG"] += outputs["ARI+BG"] / len(data_loader) 57 | test_results["ARI-BG"] += outputs["ARI-BG"] / len(data_loader) 58 | 59 | utils.print_results(partition, step, time.time() - test_time, test_results) 60 | 61 | model.train() 62 | 63 | 64 | @hydra.main(config_path="config", config_name="config") 65 | def my_main(opt: DictConfig) -> None: 66 | opt = utils.parse_args(opt) 67 | 68 | model, optimizer = model_utils.get_model_and_optimizer(opt) 69 | train_loader, train_iterator = data_utils.get_data(opt, "train") 70 | 71 | optimizer, step = train(opt, model, optimizer, train_iterator, train_loader) 72 | 73 | validate_or_test(opt, step, model, "test") 74 | 75 | 76 | if __name__ == "__main__": 77 | my_main() 78 | -------------------------------------------------------------------------------- /codebase/model/ComplexAutoEncoder.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | 7 | from codebase.model import ComplexDecoder, ComplexEncoder 8 | from codebase.utils import utils, eval_utils 9 | 10 | 11 | class ComplexAutoEncoder(nn.Module): 12 | def __init__(self, opt): 13 | super(ComplexAutoEncoder, self).__init__() 14 | 15 | self.opt = opt 16 | 17 | self.encoder 
= ComplexEncoder.ComplexEncoder(opt) 18 | self.decoder = ComplexDecoder.ComplexDecoder(opt, self.encoder.hidden_dim) 19 | 20 | self.output_model = nn.Conv2d( 21 | self.opt.input.channel, self.opt.input.channel, 1, 1 22 | ) 23 | self._init_output_model() 24 | 25 | def _init_output_model(self): 26 | nn.init.constant_(self.output_model.weight, 1) 27 | nn.init.constant_(self.output_model.bias, 0) 28 | 29 | def _prepare_input(self, input_images): 30 | phase = torch.zeros_like(input_images) 31 | return utils.get_complex_number(input_images, phase) 32 | 33 | def _run_evaluation(self, outputs, labels): 34 | outputs = eval_utils.apply_kmeans(self.opt, outputs, labels) 35 | 36 | outputs["ARI+BG"] = eval_utils.calc_ari_score( 37 | self.opt, labels, outputs["labels_pred"], with_background=True 38 | ) 39 | outputs["ARI-BG"] = eval_utils.calc_ari_score( 40 | self.opt, labels, outputs["labels_pred"], with_background=False 41 | ) 42 | 43 | return outputs 44 | 45 | def _log_outputs(self, complex_output, reconstruction, outputs): 46 | outputs["reconstruction"] = reconstruction 47 | outputs["phase"] = complex_output.angle() 48 | outputs["norm_magnitude"] = utils.clip_and_rescale( 49 | complex_output.abs(), self.opt.evaluation.phase_mask_threshold 50 | ) 51 | return outputs 52 | 53 | def _apply_module(self, module, channel_norm, z): 54 | m, phi = module(z) 55 | z = self._apply_activation_function(m, phi, channel_norm) 56 | return z 57 | 58 | def _apply_activation_function(self, m, phi, channel_norm): 59 | m = channel_norm(m) 60 | m = torch.nn.functional.relu(m) 61 | return utils.get_complex_number(m, phi) 62 | 63 | def _apply_conv_layers(self, model, z): 64 | for idx, _ in enumerate(model.conv_model): 65 | z = self._apply_module(model.conv_model[idx], model.channel_norm[idx], z) 66 | return z 67 | 68 | def encode(self, x): 69 | # Apply convolutional layers. 70 | z = self._apply_conv_layers(self.encoder, x) 71 | 72 | # Apply linear layer. 73 | z = rearrange(z, "b c h w -> b (c h w)") 74 | z = self._apply_module(self.encoder.linear, self.encoder.channel_norm[-1], z) 75 | 76 | return z 77 | 78 | def decode(self, z): 79 | # Apply linear layer. 80 | z = self._apply_module(self.decoder.linear, self.decoder.channel_norm[-1], z) 81 | 82 | z = rearrange( 83 | z, 84 | "b (c h w) -> b c h w", 85 | b=self.opt.input.batch_size, 86 | h=self.decoder.hidden_dim[0], 87 | w=self.decoder.hidden_dim[1], 88 | ) 89 | 90 | # Apply convolutional layers. 91 | complex_output = self._apply_conv_layers(self.decoder, z) 92 | 93 | # Handle output. 
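# The reconstruction is computed from the magnitudes only: output_model is a 1x1 convolution
# (initialized to weight 1 and bias 0 in _init_output_model), followed by a sigmoid.
# The phases of complex_output are not part of the reconstruction loss; they are logged in
# _log_outputs and clustered with k-means during evaluation.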
94 | output_magnitude = complex_output.abs() 95 | reconstruction = self.output_model(output_magnitude) 96 | reconstruction = torch.sigmoid(reconstruction) 97 | 98 | return reconstruction, complex_output 99 | 100 | def forward( 101 | self, input_images, labels, step, partition="train", evaluate=False, 102 | ): 103 | start_time = time.time() 104 | complex_input = self._prepare_input(input_images) 105 | 106 | z = self.encode(complex_input) 107 | reconstruction, complex_output = self.decode(z) 108 | 109 | outputs = {"loss": nn.functional.mse_loss(reconstruction, input_images)} 110 | 111 | if step % self.opt.training.print_idx == 0 or evaluate: 112 | outputs = self._log_outputs(complex_output, reconstruction, outputs) 113 | outputs = self._run_evaluation(outputs, labels) 114 | 115 | if partition == "train": 116 | utils.print_results(partition, step, time.time() - start_time, outputs) 117 | 118 | return outputs 119 | -------------------------------------------------------------------------------- /codebase/model/ComplexDecoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from codebase.model import ComplexLayers, model_utils 4 | 5 | 6 | class ComplexDecoder(nn.Module): 7 | def __init__(self, opt, hidden_dim): 8 | super().__init__() 9 | 10 | self.opt = opt 11 | 12 | self.out_channel = [ 13 | 2 * self.opt.model.hidden_dim, 14 | 2 * self.opt.model.hidden_dim, 15 | self.opt.model.hidden_dim, 16 | self.opt.model.hidden_dim, 17 | self.opt.input.channel, 18 | ] 19 | 20 | self.conv_model = nn.ModuleList( 21 | [ 22 | ComplexLayers.ComplexConvTranspose2d( 23 | opt, 24 | 2 * self.opt.model.hidden_dim, 25 | self.out_channel[0], 26 | kernel_size=3, 27 | output_padding=1, 28 | padding=1, 29 | stride=2, 30 | ), # e.g. 4x4 => 8x8. 31 | ComplexLayers.ComplexConv2d( 32 | opt, 33 | self.out_channel[0], 34 | self.out_channel[1], 35 | kernel_size=3, 36 | padding=1, 37 | ), 38 | ComplexLayers.ComplexConvTranspose2d( 39 | opt, 40 | self.out_channel[1], 41 | self.out_channel[2], 42 | kernel_size=3, 43 | output_padding=1, 44 | padding=1, 45 | stride=2, 46 | ), # e.g. 8x8 => 16x16. 47 | ComplexLayers.ComplexConv2d( 48 | opt, 49 | self.out_channel[2], 50 | self.out_channel[3], 51 | kernel_size=3, 52 | padding=1, 53 | ), 54 | ComplexLayers.ComplexConvTranspose2d( 55 | opt, 56 | self.out_channel[3], 57 | self.out_channel[4], 58 | kernel_size=3, 59 | output_padding=1, 60 | padding=1, 61 | stride=2, 62 | ), # e.g. 16x16 => 32x32. 
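# Together, the three stride-2 transposed convolutions mirror the encoder and upsample
# the feature maps back to the input resolution (e.g. 4x4 => 32x32), with stride-1
# convolutions in between that keep the resolution fixed.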
63 | ] 64 | ) 65 | 66 | self.hidden_dim = hidden_dim 67 | 68 | linear_out = ( 69 | 2 * self.hidden_dim[0] * self.hidden_dim[1] * self.opt.model.hidden_dim 70 | ) 71 | self.linear = ComplexLayers.ComplexLinear( 72 | opt, self.opt.model.linear_dim, linear_out 73 | ) 74 | 75 | self.channel_norm = model_utils.init_channel_norm_2d( 76 | self.out_channel, linear_out, self.opt 77 | ) 78 | -------------------------------------------------------------------------------- /codebase/model/ComplexEncoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from codebase.model import ComplexLayers, model_utils 5 | 6 | 7 | class ComplexEncoder(nn.Module): 8 | def __init__(self, opt): 9 | super().__init__() 10 | 11 | self.opt = opt 12 | 13 | self.out_channel = [ 14 | self.opt.model.hidden_dim, 15 | self.opt.model.hidden_dim, 16 | 2 * self.opt.model.hidden_dim, 17 | 2 * self.opt.model.hidden_dim, 18 | 2 * self.opt.model.hidden_dim, 19 | ] 20 | 21 | self.conv_model = nn.ModuleList( 22 | [ 23 | ComplexLayers.ComplexConv2d( 24 | opt, 25 | self.opt.input.channel, 26 | self.out_channel[0], 27 | kernel_size=3, 28 | padding=1, 29 | stride=2, 30 | ), # e.g. 32x32 => 16x16. 31 | ComplexLayers.ComplexConv2d( 32 | opt, 33 | self.out_channel[0], 34 | self.out_channel[1], 35 | kernel_size=3, 36 | padding=1, 37 | ), 38 | ComplexLayers.ComplexConv2d( 39 | opt, 40 | self.out_channel[1], 41 | self.out_channel[2], 42 | kernel_size=3, 43 | padding=1, 44 | stride=2, 45 | ), # e.g. 16x16 => 8x8. 46 | ComplexLayers.ComplexConv2d( 47 | opt, 48 | self.out_channel[2], 49 | self.out_channel[3], 50 | kernel_size=3, 51 | padding=1, 52 | ), 53 | ComplexLayers.ComplexConv2d( 54 | opt, 55 | self.out_channel[3], 56 | self.out_channel[4], 57 | kernel_size=3, 58 | padding=1, 59 | stride=2, 60 | ), # e.g. 8x8 => 4x4. 
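# Together, the three stride-2 convolutions downsample the input by a factor of 8
# (e.g. 32x32 => 4x4) while the channel count grows from hidden_dim to 2 * hidden_dim.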
61 | ] 62 | ) 63 | 64 | self.hidden_dim = self.get_hidden_dimension() 65 | self.linear = ComplexLayers.ComplexLinear( 66 | opt, 67 | 2 * self.hidden_dim[0] * self.hidden_dim[1] * self.opt.model.hidden_dim, 68 | self.opt.model.linear_dim, 69 | ) 70 | 71 | self.channel_norm = model_utils.init_channel_norm_2d( 72 | self.out_channel, self.opt.model.linear_dim, self.opt 73 | ) 74 | 75 | def get_hidden_dimension(self): 76 | x = torch.zeros( 77 | 1, 78 | self.opt.input.channel, 79 | self.opt.input.image_height, 80 | self.opt.input.image_width, 81 | ) 82 | for module in self.conv_model: 83 | x = module.conv(x) 84 | 85 | return x.shape[2], x.shape[3] 86 | -------------------------------------------------------------------------------- /codebase/model/ComplexLayers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from codebase.model import model_utils 5 | from codebase.utils import utils 6 | 7 | 8 | def apply_layer(real_function, phase_bias, magnitude_bias, x): 9 | psi = real_function(x.real) + 1j * real_function(x.imag) 10 | m_psi = psi.abs() + magnitude_bias 11 | phi_psi = utils.stable_angle(psi) + phase_bias 12 | 13 | chi = real_function(x.abs()) + magnitude_bias 14 | m = 0.5 * m_psi + 0.5 * chi 15 | 16 | return m, phi_psi 17 | 18 | 19 | class ComplexConvTranspose2d(nn.Module): 20 | def __init__( 21 | self, 22 | opt, 23 | in_channels, 24 | out_channels, 25 | kernel_size, 26 | stride=1, 27 | padding=0, 28 | output_padding=0, 29 | ): 30 | super(ComplexConvTranspose2d, self).__init__() 31 | 32 | self.opt = opt 33 | 34 | self.conv_tran = nn.ConvTranspose2d( 35 | in_channels, 36 | out_channels, 37 | kernel_size, 38 | stride, 39 | padding, 40 | output_padding, 41 | bias=False, 42 | ) 43 | 44 | self.kernel_size = torch.nn.modules.utils._pair(kernel_size) 45 | fan_in = out_channels * self.kernel_size[0] * self.kernel_size[1] 46 | self.magnitude_bias, self.phase_bias = model_utils.get_conv_biases( 47 | out_channels, fan_in 48 | ) 49 | 50 | def forward(self, x): 51 | return apply_layer(self.conv_tran, self.phase_bias, self.magnitude_bias, x) 52 | 53 | 54 | class ComplexConv2d(nn.Module): 55 | def __init__( 56 | self, opt, in_channels, out_channels, kernel_size=3, stride=1, padding=0, 57 | ): 58 | super(ComplexConv2d, self).__init__() 59 | 60 | self.opt = opt 61 | 62 | self.conv = nn.Conv2d( 63 | in_channels, out_channels, kernel_size, stride, padding, bias=False, 64 | ) 65 | 66 | self.kernel_size = torch.nn.modules.utils._pair(kernel_size) 67 | fan_in = in_channels * self.kernel_size[0] * self.kernel_size[1] 68 | self.magnitude_bias, self.phase_bias = model_utils.get_conv_biases( 69 | out_channels, fan_in 70 | ) 71 | 72 | def forward(self, x): 73 | return apply_layer(self.conv, self.phase_bias, self.magnitude_bias, x) 74 | 75 | 76 | class ComplexLinear(nn.Module): 77 | def __init__(self, opt, in_channels, out_channels): 78 | super(ComplexLinear, self).__init__() 79 | 80 | self.opt = opt 81 | 82 | self.fc = nn.Linear(in_channels, out_channels, bias=False) 83 | 84 | self.magnitude_bias, self.phase_bias = self._get_biases( 85 | in_channels, out_channels 86 | ) 87 | 88 | def _get_biases(self, in_channels, out_channels): 89 | fan_in = in_channels 90 | magnitude_bias = nn.Parameter(torch.empty((1, out_channels))) 91 | magnitude_bias = model_utils.init_magnitude_bias(fan_in, magnitude_bias) 92 | 93 | phase_bias = nn.Parameter(torch.empty((1, out_channels))) 94 | phase_bias = model_utils.init_phase_bias(phase_bias) 95 | return 
magnitude_bias, phase_bias 96 | 97 | def forward(self, x): 98 | return apply_layer(self.fc, self.phase_bias, self.magnitude_bias, x) 99 | -------------------------------------------------------------------------------- /codebase/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loeweX/ComplexAutoEncoder/3328961c64b4a0a63db4139ef11a8dd29a3335eb/codebase/model/__init__.py -------------------------------------------------------------------------------- /codebase/model/model_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from codebase.model import ComplexAutoEncoder 8 | 9 | 10 | def get_model_and_optimizer(opt): 11 | model = ComplexAutoEncoder.ComplexAutoEncoder(opt) 12 | 13 | print(model) 14 | print() 15 | 16 | if opt.use_cuda: 17 | model = model.cuda() 18 | 19 | optimizer = torch.optim.Adam(model.parameters(), lr=opt.training.learning_rate) 20 | 21 | return model, optimizer 22 | 23 | 24 | def init_channel_norm_2d(out_channel, linear_out, opt): 25 | channel_norm = nn.ModuleList([None] * (len(out_channel) + 1)) 26 | for idx, out_c in enumerate(out_channel): 27 | channel_norm[idx] = nn.BatchNorm2d(out_c, affine=True) 28 | 29 | channel_norm[-1] = nn.LayerNorm(linear_out, elementwise_affine=True) 30 | return channel_norm 31 | 32 | 33 | def save_model(opt, model, optimizer): 34 | file_path = os.path.join(opt.log_dir, "checkpoint.pt") 35 | print(f"Saving model to {file_path}.") 36 | torch.save( 37 | {"model": model.state_dict(), "optimizer": optimizer.state_dict()}, file_path, 38 | ) 39 | 40 | 41 | def get_conv_biases(out_channels, fan_in): 42 | magnitude_bias = nn.Parameter(torch.empty((1, out_channels, 1, 1))) 43 | magnitude_bias = init_magnitude_bias(fan_in, magnitude_bias) 44 | 45 | phase_bias = nn.Parameter(torch.empty((1, out_channels, 1, 1))) 46 | phase_bias = init_phase_bias(phase_bias) 47 | return magnitude_bias, phase_bias 48 | 49 | 50 | def init_phase_bias(bias): 51 | return nn.init.constant_(bias, val=0) 52 | 53 | 54 | def init_magnitude_bias(fan_in, bias): 55 | bound = 1 / math.sqrt(fan_in) 56 | torch.nn.init.uniform_(bias, -bound, bound) 57 | return bias 58 | -------------------------------------------------------------------------------- /codebase/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/loeweX/ComplexAutoEncoder/3328961c64b4a0a63db4139ef11a8dd29a3335eb/codebase/utils/__init__.py -------------------------------------------------------------------------------- /codebase/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class NpzDataset(torch.utils.data.Dataset): 10 | """NpzDataset: loads a npz file as input.""" 11 | 12 | def __init__(self, opt, partition): 13 | self.opt = opt 14 | self.root_dir = Path(opt.cwd, opt.input.load_path) 15 | file_name = Path(self.root_dir, f"{opt.input.file_name}_{partition}.npz") 16 | 17 | self.dataset = np.load(file_name) 18 | self.images = torch.Tensor(self.dataset["images"]) 19 | self.labels = self.dataset["labels"] 20 | 21 | def __len__(self): 22 | return self.images.shape[0] 23 | 24 | def __getitem__(self, idx): 25 | images = (self.images[idx] + 1) / 2 # 
Normalize to [0, 1] range. 26 | return images, self.labels[idx] 27 | 28 | 29 | def seed_worker(worker_id): 30 | worker_seed = torch.initial_seed() % 2 ** 32 31 | np.random.seed(worker_seed) 32 | random.seed(worker_seed) 33 | 34 | 35 | def get_dataloader(opt, dataset): 36 | # Improve reproducibility in dataloader. 37 | g = torch.Generator() 38 | g.manual_seed(opt.seed) 39 | 40 | data_loader = torch.utils.data.DataLoader( 41 | dataset, 42 | batch_size=opt.input.batch_size, 43 | drop_last=True, 44 | shuffle=True, 45 | worker_init_fn=seed_worker, 46 | generator=g, 47 | num_workers=4, 48 | persistent_workers=True, 49 | ) 50 | iterator = iter(data_loader) 51 | 52 | return data_loader, iterator 53 | 54 | 55 | def get_data(opt, partition): 56 | dataset = NpzDataset(opt, partition) 57 | loader, iterator = get_dataloader(opt, dataset) 58 | return loader, iterator 59 | 60 | 61 | def get_input(opt, iterator, train_loader): 62 | try: 63 | input = next(iterator) 64 | except StopIteration: 65 | # Create new generator if the previous generator is exhausted. 66 | iterator = iter(train_loader) 67 | input = next(iterator) 68 | 69 | input_images, labels = input 70 | 71 | if opt.use_cuda: 72 | input_images = input_images.cuda() 73 | 74 | return input_images, labels 75 | -------------------------------------------------------------------------------- /codebase/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import rearrange 4 | from sklearn.cluster import KMeans 5 | from sklearn.metrics.cluster import adjusted_rand_score 6 | 7 | from codebase.utils import utils 8 | 9 | 10 | def apply_kmeans(opt, outputs, labels_true): 11 | input_phase = utils.phase_to_cartesian_coordinates( 12 | opt, outputs["phase"], outputs["norm_magnitude"] 13 | ) 14 | 15 | input_phase = utils.tensor_to_numpy(input_phase) 16 | input_phase = rearrange(input_phase, "b p c h w -> b h w (c p)") 17 | 18 | num_clusters = int(torch.max(labels_true).item()) + 1 19 | 20 | labels_pred = ( 21 | np.zeros((opt.input.batch_size, opt.input.image_height, opt.input.image_width)) 22 | + num_clusters 23 | ) 24 | 25 | # Run k-means on each image separately. 26 | for img_idx in range(opt.input.batch_size): 27 | in_phase = input_phase[img_idx] 28 | num_clusters_img = int(torch.max(labels_true[img_idx]).item()) + 1 29 | 30 | # Remove areas in which objects overlap before k-means analysis. 31 | label_idx = np.where(labels_true[img_idx].cpu().numpy() != -1) 32 | in_phase = in_phase[label_idx] 33 | 34 | # Run k-means. 35 | k_means = KMeans(n_clusters=num_clusters_img, random_state=opt.seed).fit( 36 | in_phase 37 | ) 38 | 39 | # Create result image: fill in k_means labels & assign overlapping areas to class zero. 40 | cluster_img = ( 41 | np.zeros((opt.input.image_height, opt.input.image_width)) + num_clusters 42 | ) 43 | cluster_img[label_idx] = k_means.labels_ 44 | labels_pred[img_idx] = cluster_img 45 | 46 | outputs["labels_pred"] = labels_pred 47 | return outputs 48 | 49 | 50 | def calc_ari_score(opt, labels_true, labels_pred, with_background): 51 | ari = 0 52 | for idx in range(opt.input.batch_size): 53 | if with_background: 54 | area_to_eval = np.where( 55 | labels_true[idx] > -1 56 | ) # Remove areas in which objects overlap. 57 | else: 58 | area_to_eval = np.where( 59 | labels_true[idx] > 0 60 | ) # Remove background & areas in which objects overlap. 
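# The adjusted Rand index is computed per image on the selected pixels and averaged over
# the batch; with_background=True/False corresponds to the ARI+BG and ARI-BG scores
# reported during validation and testing.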
61 | 62 | ari += adjusted_rand_score( 63 | labels_true[idx][area_to_eval], labels_pred[idx][area_to_eval] 64 | ) 65 | return ari / opt.input.batch_size 66 | -------------------------------------------------------------------------------- /codebase/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | from datetime import timedelta 5 | 6 | import numpy as np 7 | import torch 8 | from hydra.utils import get_original_cwd 9 | from omegaconf import OmegaConf 10 | from omegaconf import open_dict 11 | 12 | 13 | def parse_args(opt): 14 | with open_dict(opt): 15 | opt.log_dir = os.getcwd() 16 | print(f"Logging files in {opt.log_dir}") 17 | opt.device = "cuda:0" if opt.use_cuda else "cpu" 18 | opt.cwd = get_original_cwd() 19 | 20 | np.random.seed(opt.seed) 21 | torch.manual_seed(opt.seed) 22 | random.seed(opt.seed) 23 | 24 | save_opt(opt) 25 | print(OmegaConf.to_yaml(opt)) 26 | return opt 27 | 28 | 29 | def save_opt(opt): 30 | file_path = os.path.join(opt.log_dir, "opt.pkl") 31 | with open(file_path, "wb") as opt_file: 32 | pickle.dump(opt, opt_file) 33 | 34 | 35 | def get_learning_rate(opt, step): 36 | if opt.training.learning_rate_schedule == 0: 37 | return opt.training.learning_rate 38 | elif opt.training.learning_rate_schedule == 1: 39 | return get_linear_warmup_lr(opt, step) 40 | else: 41 | raise NotImplementedError 42 | 43 | 44 | def get_linear_warmup_lr(opt, step): 45 | if step < opt.training.warmup_steps: 46 | return opt.training.learning_rate * step / opt.training.warmup_steps 47 | else: 48 | return opt.training.learning_rate 49 | 50 | 51 | def update_learning_rate(optimizer, opt, step): 52 | lr = get_learning_rate(opt, step) 53 | optimizer.param_groups[0]["lr"] = lr 54 | return optimizer, lr 55 | 56 | 57 | def tensor_to_numpy(input_tensor): 58 | return input_tensor.detach().cpu().numpy() 59 | 60 | 61 | def spherical_to_cartesian_coordinates(x): 62 | # Second dimension of x contains spherical coordinates: (r, phi_1, ... phi_n). 63 | num_dims = x.shape[1] 64 | out = torch.zeros_like(x) 65 | 66 | r = x[:, 0] 67 | phi = x[:, 1:] 68 | 69 | sin_component = 1 70 | for i in range(num_dims - 1): 71 | out[:, i] = r * torch.cos(phi[:, i]) * sin_component 72 | sin_component = sin_component * torch.sin(phi[:, i]) 73 | 74 | out[:, -1] = r * sin_component 75 | return out 76 | 77 | 78 | def phase_to_cartesian_coordinates(opt, phase, norm_magnitude): 79 | # Map phases on unit-circle and transform to cartesian coordinates. 80 | unit_circle_phase = torch.concat( 81 | (torch.ones_like(phase)[:, None], phase[:, None]), dim=1 82 | ) 83 | 84 | if opt.evaluation.phase_mask_threshold != -1: 85 | # When magnitude is < phase_mask_threshold, use as multiplier to mask out respective phases from eval. 
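# norm_magnitude is produced by clip_and_rescale (see ComplexAutoEncoder._log_outputs) and lies
# in [0, 1], so phases at weakly activated locations are down-weighted before k-means clustering.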
86 | unit_circle_phase = unit_circle_phase * norm_magnitude[:, None] 87 | 88 | return spherical_to_cartesian_coordinates(unit_circle_phase) 89 | 90 | 91 | def clip_and_rescale(input_tensor, clip_value): 92 | if torch.is_tensor(input_tensor): 93 | clipped = torch.clamp(input_tensor, min=0, max=clip_value) 94 | elif isinstance(input_tensor, np.ndarray): 95 | clipped = np.clip(input_tensor, a_min=0, a_max=clip_value) 96 | else: 97 | raise NotImplementedError 98 | 99 | return clipped * (1 / clip_value) 100 | 101 | 102 | def get_complex_number(magnitude, phase): 103 | return magnitude * torch.exp(phase * 1j) 104 | 105 | 106 | def complex_tensor_to_real(complex_tensor, dim=-1): 107 | return torch.stack([complex_tensor.real, complex_tensor.imag], dim=dim) 108 | 109 | 110 | def stable_angle(x: torch.tensor, eps=1e-8): 111 | """ Function to ensure that the gradients of .angle() are well behaved.""" 112 | imag = x.imag 113 | y = x.clone() 114 | y.imag[(imag < eps) & (imag > -1.0 * eps)] = eps 115 | return y.angle() 116 | 117 | 118 | def print_results(partition, step, time_spent, outputs): 119 | print( 120 | f"{partition} \t \t" 121 | f"Step: {step} \t" 122 | f"Time: {timedelta(seconds=time_spent)} \t" 123 | f"MSE Loss: {outputs['loss'].item():.4e} \t" 124 | f"ARI+BG: {outputs['ARI+BG']:.4e} \t" 125 | f"ARI-BG: {outputs['ARI-BG']:.4e} \t" 126 | ) 127 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: CAE 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.8 6 | - scikit-learn 7 | - numpy 8 | - scipy 9 | - tqdm 10 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Create conda environment. 4 | conda env create -f environment.yml 5 | conda activate CAE 6 | 7 | # Install additional packages. 8 | conda install pytorch=1.11 torchvision torchaudio cudatoolkit=11.3 -c pytorch 9 | pip install hydra-core 10 | pip install einops 11 | 12 | # Download datasets. 13 | wget https://www.dropbox.com/s/hcmin7jmem7pfn8/datasets.zip 14 | unzip datasets.zip 15 | rm datasets.zip 16 | --------------------------------------------------------------------------------