├── .gitignore ├── KDD22 ├── fdsa_poster.pdf └── fdsa_presentation.pdf ├── LICENSE ├── README.md ├── conda.yml ├── examples ├── README.md ├── molecule_generation │ ├── conda.yml │ ├── encoder_train_params.json │ ├── model_params.json │ └── mol_gen.py ├── set_autoencoder │ ├── conda.yml │ ├── reconstruct_128D_cancer_data │ │ ├── matching_params.json │ │ ├── setsae_setm_train.py │ │ └── train_params.json │ └── shapes │ │ ├── shapes_train.json │ │ └── shapes_train.py └── set_matching │ ├── conda.yml │ ├── rnn │ ├── birnn_params.json │ ├── rnn_params.json │ └── rnn_setmatching.py │ └── seq2seq │ ├── seq2seq_params.json │ └── seq2seq_setmatching.py ├── fdsa ├── __init__.py ├── datasets │ ├── __init__.py │ ├── galaxy_data.py │ ├── shapes_data.py │ ├── tests │ │ ├── test_galaxy_data.py │ │ └── test_shapes_data.py │ └── torch_dataset.py ├── models │ ├── __init__.py │ ├── decoders │ │ ├── __init__.py │ │ ├── decoder_sets_ae.py │ │ └── tests │ │ │ └── __init__.py │ ├── encoders │ │ ├── __init__.py │ │ ├── deepsets.py │ │ ├── encoder_sets_ae.py │ │ └── tests │ │ │ └── test_set_ae.py │ ├── set_matching │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── dnn.py │ │ ├── rnn.py │ │ ├── selectrnn.py │ │ ├── seq2seq.py │ │ ├── seq2seq_decoder.py │ │ └── seq2seq_encoder.py │ └── sets_autoencoder.py └── utils │ ├── __init__.py │ ├── gale_shapley.py │ ├── helper.py │ ├── hyperparameters.py │ ├── layers │ ├── __init__.py │ ├── peephole_lstm.py │ ├── select_item.py │ └── tests │ │ └── test_peephole_lstm.py │ ├── loss_setae.py │ ├── loss_setmatching.py │ ├── mapper.py │ └── setsae_setm.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # mac files 10 | .DS_Store 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /KDD22/fdsa_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaccMann/fdsa/0e592e5281df69b3da2ea004ad3d96d9ca286f4d/KDD22/fdsa_poster.pdf -------------------------------------------------------------------------------- /KDD22/fdsa_presentation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaccMann/fdsa/0e592e5281df69b3da2ea004ad3d96d9ca286f4d/KDD22/fdsa_presentation.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 PaccMann 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 2 | # Fully Differentiable Set Autoencoder (fdsa) 3 | 4 | A fully differentiable set autoencoder for encoding sets. [Paper @KDD 2022](https://dl.acm.org/doi/10.1145/3534678.3539153). 5 | 6 | 7 | The work is inspired by ["The Set Autoencoder: Unsupervised Representation Learning for Sets "](https://openreview.net/forum?id=r1tJKuyRZ). The model makes use of an 8 | encoder from ["Order Matters: Sequence to sequence for sets"](https://arxiv.org/abs/1511.06391) and the decoder is a slightly modified version of the one in ["The Set Autoencoder: Unsupervised Representation Learning for Sets "](https://openreview.net/forum?id=r1tJKuyRZ). 
To create a differentiable loss function, the reconstructions of the autoencoder must be matched efficiently to their corresponding inputs. To this end, three architectures that approximate the assignment problem were developed and evaluated, each acting as an end-to-end
9 | set matching network. The package includes code for these networks, as well as baseline implementations of the set autoencoder fitted with the Hungarian matching algorithm and the Gale-Shapley algorithm.
10 | 
11 | ## Installation
12 | 
13 | Create a conda environment:
14 | 
15 | ```console
16 | conda env create -f conda.yml
17 | ```
18 | 
19 | Activate the environment:
20 | 
21 | ```console
22 | conda activate fdsa
23 | ```
24 | 
25 | Install:
26 | 
27 | ```console
28 | pip install .
29 | ```
30 | 
31 | ### Development
32 | 
33 | Install in editable mode for development:
34 | 
35 | ```sh
36 | pip install --user -e .
37 | ```
38 | 
39 | ## Examples
40 | 
41 | For examples of how to use `fdsa`, see [here](./examples).
42 | 
43 | ## Citation
44 | 
45 | If you use `fdsa` in your projects, please cite:
46 | 
47 | 
48 | ```bib
49 | @inproceedings{10.1145/3534678.3539153,
50 | author = {Janakarajan, Nikita and Born, Jannis and Manica, Matteo},
51 | title = {A Fully Differentiable Set Autoencoder},
52 | year = {2022},
53 | isbn = {9781450393850},
54 | publisher = {Association for Computing Machinery},
55 | address = {New York, NY, USA},
56 | url = {https://doi.org/10.1145/3534678.3539153},
57 | doi = {10.1145/3534678.3539153},
58 | booktitle = {Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
59 | pages = {3061–3071},
60 | numpages = {11},
61 | keywords = {set matching network, multi-modality, autoencoders, sets},
62 | location = {Washington DC, USA},
63 | series = {KDD '22}
64 | }
65 | ```
66 | 
-------------------------------------------------------------------------------- /conda.yml: --------------------------------------------------------------------------------
1 | name: fdsa
2 | channels:
3 |   - conda-forge
4 | dependencies:
5 |   - python>=3.7,<3.8
6 |   - pip>=19.1
7 |   - pip:
8 |     - numpy>=1.14.3
9 |     - scipy>=1.3.1
10 |     - pytoda @ git+https://github.com/PaccMann/paccmann_datasets@0.2.4
11 |     - torch>=1.3.0
12 |     - brc-pytorch>=0.1.3
13 |     - requests>=2.23.0
14 |     # data analysis
15 |     - pandas>=0.24.2,<1.0
16 |     - scikit-learn==0.22.2
17 |     - matplotlib>=3.1.1
18 |     - seaborn>=0.9.0
19 |     - astropy==4.0.1.post1
20 |     - scikit-image==0.16.2
21 |     - imageio==2.6.1
22 |     # dev tools
23 |     - pytest==5.4.2
24 | 
25 | 
-------------------------------------------------------------------------------- /examples/README.md: --------------------------------------------------------------------------------
1 | # fdsa - Examples
2 | 
3 | Here we report some `fdsa` usage examples.
4 | Example data can be downloaded [here](https://ibm.box.com/v/paccmann-sets-data).
5 | 
6 | 
7 | ## Training a Set Matching Network
8 | 
9 | The set matching network simply plays the role of a mapper in the set autoencoder and is
10 | not optimised during training. Therefore, it must first be pre-trained before it can be
11 | plugged into the set autoencoder.
12 | 
13 | The synthetic data to train the set matching network can be generated at runtime
14 | using the `pytoda` package. The dataset can be customised with respect to its
15 | distribution, shape, and the cost metric used to generate the assignment targets for
16 | supervised learning.
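For intuition, the following is a minimal, self-contained sketch of what one such training sample amounts to (illustrative only, not the `pytoda` API): `set2` is a permuted, noise-corrupted copy of `set1`, and the targets are the optimal assignment between the two sets under the chosen cost metric.

```python
# Illustrative sketch (not the pytoda API): one synthetic set matching sample
# drawn from a normal distribution, with Hungarian targets under a 2-norm cost.
import torch
from scipy.optimize import linear_sum_assignment

length, dim, noise_std = 5, 128, 0.5
set1 = torch.distributions.Normal(0.0, 1.0).sample((length, dim))
perm = torch.randperm(length)
set2 = set1[perm] + noise_std * torch.randn(length, dim)

cost = torch.cdist(set1, set2, p=2)            # set1 vs set2 cost matrix
_, col = linear_sum_assignment(cost.numpy())   # optimal assignment (Hungarian)
target12 = torch.as_tensor(col)     # element i of set1 matches set2[target12[i]]
target21 = torch.argsort(target12)  # element j of set2 matches set1[target21[j]]
```

With low noise, `target12` recovers the inverse of `perm`; the networks below are trained to predict such assignments directly from the two input sets.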
17 | See the training scripts in the [set_matching](./set_matching/) folder for how to create and use these custom synthetic datasets.
18 | 
19 | ### RNN Model
20 | The example [rnn_setmatching.py](./set_matching/rnn/rnn_setmatching.py) performs a
21 | comparative analysis of RNNs using four recurrent cells, namely GRU, LSTM, nBRC and BRC,
22 | in the set matching task. The parameters for training are provided as JSON files.
23 | 
24 | ```console
25 | (fdsa) $ python examples/set_matching/rnn/rnn_setmatching.py -h
26 | usage: rnn_setmatching.py [-h] model_path results_path training_params
27 | 
28 | positional arguments:
29 |   model_path        Path to save best model.
30 |   results_path      Path to save results and plots.
31 |   training_params   Path to the training parameters.
32 | 
33 | optional arguments:
34 |   -h, --help        show this help message and exit
35 | ```
36 | The RNN models return a set2 vs set1 constrained similarity matrix, and so the predictions
37 | are compared with `target21`.
38 | 
39 | ### Sequence-to-Sequence Model
40 | The example [seq2seq_setmatching.py](./set_matching/seq2seq/seq2seq_setmatching.py)
41 | performs a comparative analysis of sequence-to-sequence models with attention, using four
42 | recurrent cells, namely GRU, LSTM, nBRC and BRC, as the encoding/decoding unit in the
43 | set matching task. The parameters for training are provided as JSON files.
44 | 
45 | ```console
46 | (fdsa) $ python examples/set_matching/seq2seq/seq2seq_setmatching.py -h
47 | usage: seq2seq_setmatching.py [-h] model_path results_path training_params
48 | 
49 | positional arguments:
50 |   model_path        Path to save best model.
51 |   results_path      Path to save results and plots.
52 |   training_params   Path to the training parameters.
53 | 
54 | optional arguments:
55 |   -h, --help        show this help message and exit
56 | ```
57 | 
58 | The Seq2Seq models return a set1 vs set2 constrained similarity matrix, and so the
59 | predictions are compared with `target12`.
60 | 
61 | NOTE: The dimension along which the `argmax` is computed in the constrained similarity matrix
62 | decides whether `target12` or `target21` is used as the true target for the loss calculation.
63 | `target12` holds the row-wise non-zero indices of a set1 vs set2 constrained similarity matrix,
64 | and `target21` holds the row-wise non-zero indices of a set2 vs set1 constrained similarity matrix (cf. the illustrative sketch above).
65 | 
66 | 
67 | 
68 | ## Training a Fully Differentiable Set Autoencoder
69 | 
70 | The [set_autoencoder](./set_autoencoder/) folder contains examples for two tasks:
71 | reconstructing 2D shapes and reconstructing 128D cancer data.
72 | 
73 | ### Reconstructing 2D Shapes
74 | This task serves as a sanity check for the set autoencoder and therefore only uses the Hungarian algorithm.
75 | The data for the shapes task consist of randomly generated 2D point clouds of squares, crosses and circles
76 | of various sizes and positions, each made up of a varying number of points. The data can be generated
77 | using the [shapes_data.py](../fdsa/datasets/shapes_data.py) script.
78 | The data are saved as a `.csv` file, which is then passed into the [shapes_train.py](./set_autoencoder/shapes/shapes_train.py)
79 | script. This script performs a comparative analysis of GRU, LSTM, pLSTM and nBRC recurrent cells
80 | as the encoding/decoding unit in the set autoencoder on reconstructing 2D shapes. The
81 | training parameters are provided as a JSON file. A minimal sketch of the point-cloud generation follows.
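This is an illustrative stand-in only: the size and position ranges below are assumptions, and the real generator (covering all three shapes) is [shapes_data.py](../fdsa/datasets/shapes_data.py).

```python
# Hypothetical stand-in for the square generator in shapes_data.py; the size
# and position ranges are assumptions chosen for illustration.
import numpy as np

def sample_square(n_points: int, rng: np.random.Generator) -> np.ndarray:
    """Sample `n_points` 2D points on the perimeter of a random square."""
    centre = rng.uniform(-0.25, 0.25, size=2)    # random position
    half = rng.uniform(0.05, 0.2)                # random size (half side length)
    t = rng.uniform(-half, half, size=n_points)  # coordinate along an edge
    side = rng.integers(0, 4, size=n_points)     # which edge each point lies on
    pts = np.empty((n_points, 2))
    pts[:, 0] = np.where(side < 2, t, np.where(side == 2, -half, half))
    pts[:, 1] = np.where(side >= 2, t, np.where(side == 0, -half, half))
    return centre + pts

rng = np.random.default_rng(0)
cloud = sample_square(int(rng.integers(10, 34)), rng)  # sets vary in cardinality
```

Each cloud is written to the `.csv` together with a shape `label` and a set `ID`, matching the `['x', 'y']`, `['label']` and `['ID']` columns that `ToySetsDataset` expects in `shapes_train.py`.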
82 | 
83 | ```console
84 | (fdsa) $ python examples/set_autoencoder/shapes/shapes_train.py -h
85 | usage: shapes_train.py [-h]
86 |                        model_path results_path training_data_path
87 |                        validation_data_path testing_data_path training_params
88 | 
89 | positional arguments:
90 |   model_path            Path to save the best performing model.
91 |   results_path          Path to save the training and validation losses.
92 |   training_data_path    Path to the training data.
93 |   validation_data_path  Path to the validation data.
94 |   testing_data_path     Path to the testing data.
95 |   training_params       Path to the training parameters json file.
96 | 
97 | optional arguments:
98 |   -h, --help            show this help message and exit
99 | ```
100 | ### Reconstructing 128D Cancer Data
101 | 
102 | The data are provided as `torch` files that contain tensors of transcriptomic and protein data
103 | combined into sets, and split into training, validation and test datasets, such that no cell line
104 | or protein seen during training is used for validation or testing. Additionally, the permutations
105 | used to shuffle the order of the elements of each set are also provided.
106 | The files can be found [here](https://ibm.box.com/v/paccmann-sets-data).
107 | 
108 | The training script [setsae_setm_train.py](./set_autoencoder/reconstruct_128D_cancer_data/setsae_setm_train.py)
109 | generates reconstructions for the specified parameters, which are provided as JSON files.
110 | The Hungarian or Gale-Shapley algorithm can be used in place of the pre-trained network
111 | as a baseline by commenting and uncommenting the lines pertaining to the mapper.
112 | Note: even if the experiment uses the Hungarian/Gale-Shapley algorithm, a dummy matching-parameters file and
113 | model path should still be provided, to retain the flexibility of switching between the
114 | network and the algorithms.
115 | 
116 | The pre-trained matching network can be found [here](https://ibm.box.com/v/paccmann-sets-matching-network).
117 | 
118 | ```console
119 | (fdsa) $ python examples/set_autoencoder/reconstruct_128D_cancer_data/setsae_setm_train.py -h
120 | usage: setsae_setm_train.py [-h]
121 |                             model_path results_path training_params
122 |                             matching_params train_data_path valid_data_path
123 |                             test_data_path
124 | 
125 | positional arguments:
126 |   model_path        Path where the pre-trained matching network is saved.
127 |   results_path      Path to save the results, logs and best model.
128 |   training_params   Path to the training parameters json file.
129 |   matching_params   Path to the matching network parameters json file.
130 |   train_data_path   Path to the training data.
131 |   valid_data_path   Path to the validation data.
132 |   test_data_path    Path to the testing data.
133 | 
134 | optional arguments:
135 |   -h, --help        show this help message and exit
136 | ```
137 | 
138 | 
139 | ## Molecule Generation
140 | 
141 | The example [mol_gen.py](./molecule_generation/mol_gen.py) uses the encoder of the pre-trained set autoencoder to produce embeddings of sets of transcriptomic and proteomic data. This embedding acts as a multi-modal context for molecule generation in the PaccMann^RL generative model, which generates candidate drugs against a given cancer type.
142 | 
143 | The experiment makes use of sets of transcriptomic profiles of cell lines and associated proteins to condition molecule generation. A leave-one-out cross-validation (LOOCV) over the cell lines is used to evaluate the performance of this model; that is, the test set comprises the Cartesian product of the held-out cell line and all proteins under consideration.
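Concretely, the LOOCV split in [mol_gen.py](./molecule_generation/mol_gen.py) reduces to the following (the toy data frames stand in for the real omics and protein tables):

```python
# Toy illustration of the LOOCV split used in mol_gen.py: train on all other
# cell lines, test on the Cartesian product of the held-out line and proteins.
from itertools import product

import pandas as pd

omics_df = pd.DataFrame({'cell_line': ['CL1', 'CL2', 'CL3']})  # toy stand-ins
protein_df = pd.DataFrame({'entry_name': ['P1', 'P2']})
test_cell_line = 'CL3'

train_omics = omics_df[omics_df['cell_line'] != test_cell_line]['cell_line']
test_sets = list(product([test_cell_line], protein_df['entry_name']))
assert len(test_sets) == len(protein_df)  # one test pair per protein
```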
144 | 
145 | The [molecule_generation](./molecule_generation) folder contains the necessary JSON parameter files for training. The omics data, protein data and unbiased predictions can be found [here](https://ibm.box.com/v/paccmann-sets-data). The pre-trained models required to successfully run the script can be found as follows:
146 | 1. Encoder model and its parameters (encoder_model_path, encoder_params_path) can be found [here](https://ibm.box.com/v/paccmann-sets-autoencoder)
147 | 2. Molecule model (mol_model_path) can be found [here](https://ibm.box.com/v/paccmann-affinity-selfies024)
148 | 3. IC50 model (ic50_model_path) can be found [here](https://ibm.box.com/v/paccmann-pytoda-ic50)
149 | 4. Affinity model (affinity_model_path) can be found [here](https://ibm.box.com/v/paccmann-affinity-base)
150 | 5. Tox21 model (tox21_path) can be found [here](https://ibm.ent.box.com/folder/122603684362?v=paccmann-sarscov2-data)
151 | under `pretraining/toxicity_predictor`.
152 | 
153 | ```console
154 | (fdsa) $ python examples/molecule_generation/mol_gen.py -h
155 | usage: mol_gen.py [-h] [--test_protein_name TEST_PROTEIN_NAME]
156 |                   [--tox21_path TOX21_PATH]
157 |                   omics_data_path protein_data_path test_cell_line
158 |                   encoder_model_path mol_model_path ic50_model_path
159 |                   affinity_model_path params_path encoder_params_path
160 |                   results_path unbiased_protein_path unbiased_omics_path site
161 |                   model_name
162 | 
163 | PaccMann^RL training script
164 | 
165 | positional arguments:
166 |   omics_data_path       Omics data path to condition molecule generation.
167 |   protein_data_path     Protein data path to condition molecule generation.
168 |   test_cell_line        Name of testing cell line (LOOCV).
169 |   encoder_model_path    Path to setAE model.
170 |   mol_model_path        Path to chemistry model.
171 |   ic50_model_path       Path to pretrained IC50 model.
172 |   affinity_model_path   Path to pretrained affinity model.
173 |   params_path           Directory containing the model params JSON file.
174 |   encoder_params_path   directory containing the encoder parameters JSON file.
175 |   results_path          Path where results are saved.
176 |   unbiased_protein_path
177 |                         Path where unbiased protein predictions are saved.
178 |   unbiased_omics_path   Path where unbiased omics predictions are saved.
179 |   site                  Name of the cancer site for conditioning generation.
180 |   model_name            Name for the trained model.
181 | 
182 | optional arguments:
183 |   -h, --help            show this help message and exit
184 |   --test_protein_name TEST_PROTEIN_NAME
185 |                         Optional gene name of testing protein (LOOCV).
186 |   --tox21_path TOX21_PATH
187 |                         Optional path to Tox21 model.
188 | ```
189 | 
190 | For more examples, see other repositories in the [PaccMann organization](https://github.com/PaccMann).
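For reference, a purely illustrative invocation of `mol_gen.py` could look as follows; every path and name below is a placeholder, not a file shipped with the repository:

```console
(fdsa) $ python examples/molecule_generation/mol_gen.py \
    data/omics.csv data/proteins.csv MCF7 \
    models/set_encoder.pt models/mol_vae models/ic50 models/affinity \
    examples/molecule_generation/model_params.json \
    examples/molecule_generation/encoder_train_params.json \
    results data/unbiased_protein data/unbiased_omics brain my_fdsa_model \
    --tox21_path models/tox21
```

Note that, despite the help text, the script opens `params_path` and `encoder_params_path` directly as files, so the JSON files themselves are passed here.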
191 | 192 | -------------------------------------------------------------------------------- /examples/molecule_generation/conda.yml: -------------------------------------------------------------------------------- 1 | name: molecule_generation 2 | channels: 3 | - https://conda.anaconda.org/rdkit 4 | dependencies: 5 | - rdkit=2019.03.1 6 | - python>=3.6,<3.8 7 | - pip>=19.1 8 | - pip: 9 | - pytoda @ git+https://git@github.com/PaccMann/paccmann_datasets@0.1.1 10 | - paccmann_chemistry @ git+https://git@github.com/PaccMann/paccmann_chemistry@0.0.4 11 | - paccmann_predictor @ git+https://github.com/PaccMann/paccmann_predictor@0.0.3 12 | - paccmann_generator @ git+https://github.com/PaccMann/paccmann_generator@0.0.2 13 | - paccmann_omics @ git+https://github.com/PaccMann/paccmann_omics@0.0.1 14 | - fdsa @ git+https://github.com/PaccMann/fdsa@0.0.1 15 | - numpy>=1.14.3 16 | - pandas>=0.24.1 17 | - torch>=1.0.1 18 | - matplotlib>=2.2.2 19 | - seaborn>=0.9.0 20 | -------------------------------------------------------------------------------- /examples/molecule_generation/encoder_train_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "cell": "pLSTM", 3 | "input_size": 128, 4 | "hidden_size_linear": 256, 5 | "hidden_size_encoder": 256, 6 | "hidden_size_decoder": 256, 7 | "mapper": "rnn", 8 | "matcher": "HM", 9 | "p": 2, 10 | "loss": "CrossEntropy", 11 | "epochs": 20, 12 | "lr": 0.0001, 13 | "scheduler": "step", 14 | "lr_args": [ 15 | 10, 16 | 1.0 17 | ], 18 | "train_batch": 128, 19 | "test_batch": 128, 20 | "valid_batch": 128, 21 | "test_size": 50000, 22 | "train_size": 499082, 23 | "valid_size": 50000, 24 | "padding_value": 4.0, 25 | "masking": "False", 26 | "num_plots": 10 27 | } -------------------------------------------------------------------------------- /examples/molecule_generation/model_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "generate_len": 100, 3 | "num_molecules": 500, 4 | "temperature": 0.7, 5 | "predictor_smiles_length": 560, 6 | "weight_decay": 0.00001, 7 | "learning_rate": 0.00003, 8 | "gamma": 1, 9 | "IC50_min": -8.77435, 10 | "IC50_max": 11.83146, 11 | "IC50_threshold": 2, 12 | "clip_grad": 0.1, 13 | "epochs": 30, 14 | "steps": 5, 15 | "batch_size": 128, 16 | "eval_batch_size": 190, 17 | "qed_weight": 0, 18 | "scscore_weight": 0, 19 | "esol_weight": 0, 20 | "paccmann_weight": 2, 21 | "clintox_weight": 0, 22 | "organdb_weight": 0, 23 | "sider_weight": 0, 24 | "tox21_weight": 0, 25 | "site": "brain", 26 | "model_folder": "", 27 | "tox21_path": "" 28 | } -------------------------------------------------------------------------------- /examples/molecule_generation/mol_gen.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import sys 6 | import warnings 7 | from itertools import product 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import torch 12 | 13 | from paccmann_chemistry.models import (StackGRUDecoder, StackGRUEncoder, TeacherVAE) 14 | from paccmann_chemistry.utils import get_device 15 | from paccmann_generator.plot_utils import ( 16 | plot_and_compare, plot_and_compare_proteins, plot_loss 17 | ) 18 | from paccmann_generator.reinforce_sets import ReinforceMultiModalSets 19 | from paccmann_generator.utils import disable_rdkit_logging 20 | from paccmann_predictor.models import MODEL_FACTORY 21 | from fdsa.models.sets_autoencoder import SetsAE 22 
| from pytoda.proteins.protein_language import ProteinLanguage 23 | from pytoda.smiles.smiles_language import SMILESLanguage, SMILESTokenizer 24 | 25 | warnings.filterwarnings("ignore") 26 | 27 | parser = argparse.ArgumentParser(description='PaccMann^RL training script') 28 | 29 | parser.add_argument( 30 | 'omics_data_path', 31 | type=str, 32 | help='Omics data path to condition molecule generation.' 33 | ) 34 | parser.add_argument( 35 | 'protein_data_path', 36 | type=str, 37 | help='Protein data path to condition molecule generation.' 38 | ) 39 | 40 | parser.add_argument( 41 | 'test_cell_line', type=str, help='Name of testing cell line (LOOCV).' 42 | ) 43 | parser.add_argument('encoder_model_path', type=str, help='Path to setAE model.') 44 | 45 | parser.add_argument('mol_model_path', type=str, help='Path to chemistry model.') 46 | 47 | parser.add_argument('ic50_model_path', type=str, help='Path to pretrained IC50 model.') 48 | 49 | parser.add_argument( 50 | 'affinity_model_path', type=str, help='Path to pretrained affinity model.' 51 | ) 52 | parser.add_argument('--tox21_path', help='Optional path to Tox21 model.') 53 | 54 | parser.add_argument( 55 | 'params_path', type=str, help='Directory containing the model params JSON file.' 56 | ) 57 | parser.add_argument( 58 | 'encoder_params_path', 59 | type=str, 60 | help='directory containing the encoder parameters JSON file.' 61 | ) 62 | parser.add_argument('results_path', type=str, help='Path where results are saved.') 63 | parser.add_argument( 64 | 'unbiased_protein_path', 65 | type=str, 66 | help='Path where unbiased protein predictions are saved.' 67 | ) 68 | parser.add_argument( 69 | 'unbiased_omics_path', 70 | type=str, 71 | help='Path where unbiased omics predictions are saved.' 72 | ) 73 | 74 | parser.add_argument( 75 | 'site', type=str, help='Name of the cancer site for conditioning generation.' 
76 | ) 77 | parser.add_argument('model_name', type=str, help='Name for the trained model.') 78 | 79 | args = parser.parse_args() 80 | 81 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 82 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 83 | logger = logging.getLogger('train_paccmann_rl') 84 | logger_m = logging.getLogger('matplotlib') 85 | logger_m.setLevel(logging.WARNING) 86 | 87 | 88 | def main(*, parser_namespace): 89 | 90 | disable_rdkit_logging() 91 | 92 | # read the params json 93 | params = dict() 94 | with open(parser_namespace.params_path) as f: 95 | params.update(json.load(f)) 96 | 97 | with open(parser_namespace.encoder_params_path) as f: 98 | encoder_params = json.load(f) 99 | 100 | # results_path = params.get('results_path', parser_namespace.results_path) 101 | 102 | mol_model_path = params.get('mol_model_path', parser_namespace.mol_model_path) 103 | encoder_model_path = params.get( 104 | 'encoder_model_path', parser_namespace.encoder_model_path 105 | ) 106 | ic50_model_path = params.get('ic50_model_path', parser_namespace.ic50_model_path) 107 | omics_data_path = params.get('omics_data_path', parser_namespace.omics_data_path) 108 | affinity_model_path = params.get( 109 | 'affinity_model_path', parser_namespace.affinity_model_path 110 | ) 111 | protein_data_path = params.get( 112 | 'protein_data_path', parser_namespace.protein_data_path 113 | ) 114 | model_name = params.get( 115 | 'model_name', parser_namespace.model_name 116 | ) # yapf: disable 117 | 118 | unbiased_protein_path = params.get( 119 | 'unbiased_protein_path', parser_namespace.unbiased_protein_path 120 | ) # yapf: disable 121 | unbiased_omics_path = params.get( 122 | 'unbiased_omics_path', parser_namespace.unbiased_omics_path 123 | ) # yapf: disable 124 | site = params.get( 125 | 'site', parser_namespace.site 126 | ) # yapf: disable 127 | 128 | test_cell_line = params.get('test_cell_line', parser_namespace.test_cell_line) 129 | 130 | logger.info(f'Model with name {model_name} starts.') 131 | 132 | # passing optional paths to params to possibly update_reward_fn 133 | optional_reward_args = ['tox21_path', 'site'] 134 | for arg in optional_reward_args: 135 | if parser_namespace.__dict__[arg]: 136 | params[arg] = params.get(arg, parser_namespace.__dict__[arg]) 137 | 138 | omics_df = pd.read_csv(omics_data_path) 139 | 140 | protein_df = pd.read_csv(protein_data_path) 141 | protein_df.index = protein_df['entry_name'] 142 | 143 | # Restore SMILES Model 144 | with open(os.path.join(mol_model_path, 'model_params.json')) as f: 145 | mol_params = json.load(f) 146 | 147 | gru_encoder = StackGRUEncoder(mol_params) 148 | gru_decoder = StackGRUDecoder(mol_params) 149 | generator = TeacherVAE(gru_encoder, gru_decoder) 150 | generator.load( 151 | os.path.join( 152 | mol_model_path, f"weights/best_{params.get('smiles_metric', 'rec')}.pt" 153 | ), 154 | map_location=get_device() 155 | ) 156 | 157 | # Load languages 158 | 159 | generator_smiles_language = SMILESTokenizer( 160 | vocab_file=os.path.join(mol_model_path, 'vocab.json') 161 | ) 162 | 163 | generator.smiles_language = generator_smiles_language 164 | 165 | #load predictors 166 | with open(os.path.join(ic50_model_path, 'model_params.json')) as f: 167 | paccmann_params = json.load(f) 168 | 169 | paccmann_predictor = MODEL_FACTORY['mca'](paccmann_params) 170 | paccmann_predictor.load( 171 | os.path.join( 172 | ic50_model_path, f"weights/best_{params.get('ic50_metric', 'rmse')}_mca.pt" 173 | ), 174 | map_location=get_device() 175 | ) 176 | 
paccmann_predictor.eval() 177 | 178 | paccmann_smiles_language = SMILESLanguage.from_pretrained( 179 | pretrained_path=ic50_model_path 180 | ) 181 | 182 | paccmann_predictor._associate_language(paccmann_smiles_language) 183 | 184 | with open(os.path.join(affinity_model_path, 'model_params.json')) as f: 185 | protein_pred_params = json.load(f) 186 | 187 | protein_predictor = MODEL_FACTORY['bimodal_mca'](protein_pred_params) 188 | protein_predictor.load( 189 | os.path.join( 190 | affinity_model_path, 191 | f"weights/best_{params.get('p_metric', 'ROC-AUC')}_bimodal_mca.pt" 192 | ), 193 | map_location=get_device() 194 | ) 195 | protein_predictor.eval() 196 | 197 | affinity_smiles_language = SMILESLanguage.from_pretrained( 198 | pretrained_path=os.path.join(affinity_model_path, 'smiles_serial') 199 | ) 200 | affinity_protein_language = ProteinLanguage() 201 | 202 | protein_predictor._associate_language(affinity_smiles_language) 203 | protein_predictor._associate_language(affinity_protein_language) 204 | 205 | setsae = SetsAE(device, **encoder_params).to(device) 206 | 207 | setsae.load_state_dict(torch.load(encoder_model_path, map_location=get_device())) 208 | 209 | set_encoder = setsae.encoder 210 | set_encoder.latent_size = set_encoder.hidden_size_encoder 211 | 212 | ############################################# 213 | # Create a generator model that will be optimized 214 | gru_encoder_rl = StackGRUEncoder(mol_params) 215 | gru_decoder_rl = StackGRUDecoder(mol_params) 216 | generator_rl = TeacherVAE(gru_encoder_rl, gru_decoder_rl) 217 | generator_rl.load( 218 | os.path.join(mol_model_path, f"weights/best_{params.get('metric', 'rec')}.pt"), 219 | map_location=get_device() 220 | ) 221 | generator_rl.smiles_language = generator_smiles_language 222 | generator_rl.eval() 223 | # generator 224 | 225 | model_folder_name = test_cell_line + '_' + 'SetAE' 226 | 227 | learner = ReinforceMultiModalSets( 228 | generator_rl, set_encoder, protein_predictor, paccmann_predictor, protein_df, 229 | omics_df, params, generator_smiles_language, model_folder_name, logger, True 230 | ) 231 | 232 | train_omics = omics_df[omics_df['cell_line'] != test_cell_line]['cell_line'] 233 | train_protein = protein_df['entry_name'] 234 | 235 | # train_sets = list(product(train_omics, train_protein)) 236 | test_sets = list(product([test_cell_line], train_protein)) 237 | assert len(test_sets) == len(protein_df) 238 | 239 | unbiased_preds_ic50 = np.array( 240 | pd.read_csv(os.path.join(unbiased_omics_path, 241 | test_cell_line + '.csv'))['IC50'].values 242 | ) 243 | 244 | biased_efficacy_ratios, biased_affinity_ratios, tox_ratios = [], [], [] 245 | rewards, rl_losses = [], [] 246 | gen_mols, gen_prot, gen_cell = [], [], [] 247 | gen_affinity, gen_ic50, modes = [], [], [] 248 | proteins_tested = [] 249 | batch_size = params['batch_size'] 250 | 251 | logger.info(f'Model stored at {learner.model_path}') 252 | # total_train = len(train_sets) 253 | 254 | protein_name = None 255 | for epoch in range(1, params['epochs'] + 1): 256 | logger.info(f"Epoch {epoch:d}/{params['epochs']:d}") 257 | 258 | for step in range(1, params['steps'] + 1): 259 | cell_line = np.random.choice(train_omics) 260 | protein_name = np.random.choice(train_protein) 261 | # sample = np.random.randint(total_train) 262 | # cell_line, protein_name = train_sets[sample] 263 | 264 | logger.info(f'Current train cell: {cell_line}') 265 | logger.info(f'Current train protein: {protein_name}') 266 | 267 | rew, loss = learner.policy_gradient( 268 | cell_line, protein_name, epoch, 
batch_size 269 | ) 270 | logger.info( 271 | f"Step {step:d}/{params['steps']:d} \t loss={loss:.2f}, mean rew={rew:.2f}" 272 | ) 273 | 274 | rewards.append(rew.item()) 275 | rl_losses.append(loss) 276 | 277 | # Save model 278 | if epoch % 5 == 0: 279 | learner.save(f'gen_{epoch}.pt', f'enc_{epoch}.pt') 280 | 281 | # unbiased pred files are given by protein accession number, so convert entry_name 282 | protein_accession = protein_df.loc[protein_name, 'accession_number'] 283 | train_unbiased_preds_affinity = np.array( 284 | pd.read_csv( 285 | os.path.join(unbiased_protein_path, protein_accession + '.csv') 286 | )['affinity'].values 287 | ) 288 | 289 | train_unbiased_preds_ic50 = np.array( 290 | pd.read_csv(os.path.join(unbiased_omics_path, 291 | cell_line + '.csv'))['IC50'].values 292 | ) 293 | 294 | smiles, preds_affinity, preds_ic50, idx = learner.generate_compounds_and_evaluate( 295 | epoch, params['eval_batch_size'], protein_name, cell_line 296 | ) 297 | 298 | gs = [ 299 | s for i, s in enumerate(smiles) 300 | if preds_ic50[i] < learner.ic50_threshold and preds_affinity[i] > 0.5 301 | ] 302 | 303 | gp_ic50 = preds_ic50[(preds_ic50 < learner.ic50_threshold) 304 | & (preds_affinity > 0.5)] 305 | gp_affinity = preds_affinity[(preds_ic50 < learner.ic50_threshold) 306 | & (preds_affinity > 0.5)] 307 | 308 | for ic50, affinity, s in zip(gp_ic50, gp_affinity, gs): 309 | gen_mols.append(s) 310 | gen_cell.append(cell_line) 311 | gen_prot.append(protein_name) 312 | gen_affinity.append(affinity) 313 | gen_ic50.append(ic50) 314 | modes.append('train') 315 | 316 | plot_and_compare_proteins( 317 | train_unbiased_preds_affinity, preds_affinity, protein_name, epoch, 318 | learner.model_path, 'train', params['eval_batch_size'] 319 | ) 320 | 321 | plot_and_compare( 322 | train_unbiased_preds_ic50, preds_ic50, site, cell_line, epoch, 323 | learner.model_path, 'train', params['eval_batch_size'] 324 | ) 325 | 326 | # test_cell_line = np.random.choice(test_omics) 327 | # test_protein_name = np.random.choice(test_protein) 328 | if epoch > 10 and epoch % 5 == 0: 329 | for test_idx, test_sample in enumerate(test_sets): 330 | test_cell_line, test_protein_name = test_sample 331 | proteins_tested.append(test_protein_name) 332 | logger.info(f'EVAL cell: {test_cell_line}') 333 | logger.info(f'EVAL protein: {test_protein_name}') 334 | 335 | test_protein_accession = protein_df.loc[test_protein_name, 336 | 'accession_number'] 337 | unbiased_preds_affinity = np.array( 338 | pd.read_csv( 339 | os.path.join( 340 | unbiased_protein_path, test_protein_accession + '.csv' 341 | ) 342 | )['affinity'].values 343 | ) 344 | 345 | smiles, preds_affinity, preds_ic50, idx = ( 346 | learner.generate_compounds_and_evaluate( 347 | epoch, params['eval_batch_size'], test_protein_name, 348 | test_cell_line 349 | ) 350 | ) 351 | 352 | gs = [ 353 | s for i, s in enumerate(smiles) if 354 | preds_ic50[i] < learner.ic50_threshold and preds_affinity[i] > 0.5 355 | ] 356 | 357 | gp_ic50 = preds_ic50[(preds_ic50 < learner.ic50_threshold) 358 | & (preds_affinity > 0.5)] 359 | gp_affinity = preds_affinity[(preds_ic50 < learner.ic50_threshold) 360 | & (preds_affinity > 0.5)] 361 | 362 | for ic50, affinity, s in zip(gp_ic50, gp_affinity, gs): 363 | gen_mols.append(s) 364 | gen_cell.append(test_cell_line) 365 | gen_prot.append(test_protein_name) 366 | gen_affinity.append(affinity) 367 | gen_ic50.append(ic50) 368 | modes.append('test') 369 | 370 | inds = np.argsort(gp_ic50)[::-1] 371 | for i in inds[:5]: 372 | logger.info( 373 | f'Epoch {epoch:d}, 
generated {gs[i]} against ' 374 | f'{test_protein_name} and {test_cell_line}.\n' 375 | f'Predicted IC50 = {gp_ic50[i]}, Predicted Affinity = {gp_affinity[i]}.' 376 | ) 377 | 378 | plot_and_compare( 379 | unbiased_preds_ic50, preds_ic50, site, test_cell_line, epoch, 380 | learner.model_path, f'test_{test_protein_name}', 381 | params['eval_batch_size'] 382 | ) 383 | 384 | plot_and_compare_proteins( 385 | unbiased_preds_affinity, preds_affinity, test_protein_name, epoch, 386 | learner.model_path, 'test', params['eval_batch_size'] 387 | ) 388 | 389 | biased_affinity_ratios.append( 390 | np.round( 391 | 100 * (np.sum(preds_affinity > 0.5) / len(preds_affinity)), 1 392 | ) 393 | ) 394 | 395 | biased_efficacy_ratios.append( 396 | np.round( 397 | 100 * 398 | (np.sum(preds_ic50 < learner.ic50_threshold) / len(preds_ic50)), 399 | 1 400 | ) 401 | ) 402 | 403 | all_toxes = np.array([learner.tox21(s) for s in smiles]) 404 | tox_ratios.append( 405 | np.round(100 * (np.sum(all_toxes == 1.) / len(all_toxes)), 1) 406 | ) 407 | logger.info(f'Percentage of non-toxic compounds {tox_ratios[-1]}') 408 | 409 | toxes = [learner.tox21(s) for s in gen_mols] 410 | # Save results (good molecules!) in DF 411 | df = pd.DataFrame( 412 | { 413 | 'protein': gen_prot, 414 | 'cell_line': gen_cell, 415 | 'SMILES': gen_mols, 416 | 'IC50': gen_ic50, 417 | 'Binding probability': gen_affinity, 418 | 'mode': modes, 419 | 'Tox21': toxes 420 | } 421 | ) 422 | 423 | df.to_csv(os.path.join(learner.model_path, 'results', 'generated.csv')) 424 | # Plot loss development 425 | loss_df = pd.DataFrame({'loss': rl_losses, 'rewards': rewards}) 426 | loss_df.to_csv( 427 | learner.model_path + '/results/loss_reward_evolution.csv' 428 | ) 429 | 430 | pd.DataFrame( 431 | { 432 | 'proteins': proteins_tested, 433 | 'efficacy_ratio': biased_efficacy_ratios, 434 | 'affinity_ratio': biased_affinity_ratios, 435 | 'tox_ratio': tox_ratios 436 | } 437 | ).to_csv(learner.model_path + '/results/ratios.csv') 438 | 439 | rewards_p_all = loss_df['rewards'] 440 | losses_p_all = loss_df['loss'] 441 | plot_loss( 442 | losses_p_all, rewards_p_all, params['epochs'], learner.model_path, rolling=5 443 | ) 444 | 445 | 446 | if __name__ == '__main__': 447 | main(parser_namespace=args) 448 | -------------------------------------------------------------------------------- /examples/set_autoencoder/conda.yml: -------------------------------------------------------------------------------- 1 | name: molecule_generation 2 | channels: 3 | - https://conda.anaconda.org/rdkit 4 | dependencies: 5 | - rdkit=2019.03.1 6 | - python>=3.6,<3.8 7 | - pip>=19.1 8 | - pip: 9 | - pytoda @ git+https://git@github.com/PaccMann/paccmann_datasets@0.2.4 10 | - paccmann_chemistry @ git+https://git@github.com/PaccMann/paccmann_chemistry@0.0.4 11 | - paccmann_predictor @ git+https://github.com/PaccMann/paccmann_predictor@0.0.3 12 | - paccmann_generator @ git+https://github.com/PaccMann/paccmann_generator@0.0.2 13 | - paccmann_omics @ git+https://github.com/PaccMann/paccmann_omics@0.0.1 14 | - fdsa @ git+https://github.com/PaccMann/fdsa@0.0.1 15 | - numpy>=1.14.3 16 | - pandas>=0.24.1 17 | - torch>=1.0.1 18 | - matplotlib>=2.2.2 19 | - seaborn>=0.9.0 20 | -------------------------------------------------------------------------------- /examples/set_autoencoder/reconstruct_128D_cancer_data/matching_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_size": 500000, 3 | "valid_size": 50000, 4 | "test_size": 50000, 5 | "input_size": 128, 6 | 
"seeds": { 7 | "train1": "None", 8 | "train2": "None", 9 | "train_true": "None", 10 | "valid1": "None", 11 | "valid2": "None", 12 | "valid_true": "None", 13 | "test1": "None", 14 | "test2": "None", 15 | "test_true": "None" 16 | }, 17 | "dataset_type": "permute", 18 | "padding_value": 4.0, 19 | "noise_std": 0.0, 20 | "comment_noise":"trained on 0.2 noise std", 21 | "distribution_type": "normal", 22 | "distribution_args": { 23 | "loc": 1.0, 24 | "scale": 0.5 25 | }, 26 | "cost_metric": "p-norm", 27 | "cost_metric_args": { 28 | "p": 2 29 | }, 30 | "cell": "GRU", 31 | "layers": 1, 32 | "hidden_size": 512, 33 | "bidirectional": "False", 34 | "return_sequences": "True", 35 | "batch_first": "True", 36 | "fc_layers": 2, 37 | "fc_units": [ 38 | 128, 39 | 5 40 | ], 41 | "fc_activation": [ 42 | "lrelu", 43 | "None" 44 | ], 45 | "train_batch": 128, 46 | "valid_batch": 128, 47 | "test_batch": 128, 48 | "epochs": 30, 49 | "lr": 0.001, 50 | "max_length": 5, 51 | "min_length": 2, 52 | "loss": "ce_row", 53 | "ce_type": "row", 54 | "sinkhorn_iter": 1, 55 | "temperature": 1.0 56 | } 57 | -------------------------------------------------------------------------------- /examples/set_autoencoder/reconstruct_128D_cancer_data/setsae_setm_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import time 6 | 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import seaborn as sns 10 | import torch 11 | from fdsa.datasets.torch_dataset import Collate, SetsDataset 12 | from fdsa.models.sets_autoencoder import SetsAE 13 | from fdsa.utils.helper import (cpuStats, get_gpu_memory_map, setup_logger) 14 | from fdsa.utils.hyperparameters import LR_SCHEDULER_FACTORY 15 | from fdsa.utils.loss_setae import SetAELoss 16 | from fdsa.utils.mapper import MapperSetsAE 17 | from fdsa.utils.setsae_setm import NetworkMapperSetsAE 18 | from torch.utils.data import DataLoader 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | 'model_path', 23 | type=str, 24 | help='Path where the pre-trained matching network is saved.' 25 | ) 26 | parser.add_argument( 27 | 'results_path', type=str, help='Path to save the results, logs and best model.' 28 | ) 29 | 30 | parser.add_argument( 31 | 'training_params', type=str, help='Path to the training parameters json file.' 32 | ) 33 | 34 | parser.add_argument( 35 | 'matching_params', 36 | type=str, 37 | help='Path to the matching network parameters json file.' 38 | ) 39 | 40 | parser.add_argument('train_data_path', type=str, help='Path to the training data.') 41 | parser.add_argument('valid_data_path', type=str, help='Path to the validation data.') 42 | parser.add_argument('test_data_path', type=str, help='Path to the testing data.') 43 | 44 | # get device 45 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 46 | 47 | 48 | def main( 49 | model_path: str, 50 | results_path: str, 51 | training_params: str, 52 | matching_params: str, 53 | train_data_path: str, 54 | valid_data_path: str, 55 | test_data_path: str, 56 | ): 57 | """Executes the Set Autoencoder for the chosen matching method in 58 | reconstructing the given data. 59 | 60 | Args: 61 | model_path (str): Path where the pre-trained matching network is saved 62 | results_path (str): Path to save the results, logs and best model. 63 | training_params (str): Path to the training parameters json file. 64 | matching_params (str): Path to the matching network parameters json file. 
65 |         train_data_path (str): Path to the training data.
66 |         valid_data_path (str): Path to the validation data.
67 |         test_data_path (str): Path to the testing data.
68 |     """
69 | 
70 |     # setup logging
71 |     logger = setup_logger(
72 |         'sets', os.path.join(results_path, "setsae_setm.log"), logging.DEBUG
73 |     )
74 |     logger_mem = setup_logger(
75 |         'memory', os.path.join(results_path, 'logging_memory_time.log')
76 |     )
77 | 
78 |     clrs = sns.color_palette('Set2', 2)
79 |     fig, ax = plt.subplots(constrained_layout=True)
80 |     fig.suptitle(
81 |         'Train and Validation Loss of Sets AE on Latent Embeddings of GEP and Proteins'
82 |     )
83 |     ax.set_xlabel('Epochs')
84 |     ax.set_ylabel('Loss')
85 | 
86 |     save_here = os.path.join(results_path, 'setsae')
87 |     model = None
88 |     if model is not None:
89 |         del (model)
90 | 
91 |     with open(training_params, 'r') as readjson:
92 |         train_params = json.load(readjson)
93 | 
94 |     with open(matching_params, 'r') as readjson:
95 |         match_params = json.load(readjson)
96 | 
97 |     max_length = train_params.get('max_length', 5)
98 |     dim = train_params.get('input_size', 128)
99 |     padding_value = train_params.get('padding_value', 4.0)
100 | 
101 |     train_dataset = SetsDataset(train_data_path, device)
102 |     valid_dataset = SetsDataset(valid_data_path, device)
103 |     test_dataset = SetsDataset(test_data_path, device)
104 | 
105 |     collator = Collate(max_length, dim, padding_value, device)
106 | 
107 |     train_loader = DataLoader(
108 |         train_dataset,
109 |         batch_size=train_params['train_batch'],
110 |         shuffle=True,
111 |         collate_fn=collator
112 |     )
113 | 
114 |     valid_loader = DataLoader(
115 |         valid_dataset,
116 |         batch_size=train_params['valid_batch'],
117 |         shuffle=True,
118 |         collate_fn=collator
119 |     )
120 | 
121 |     test_loader = DataLoader(
122 |         test_dataset,
123 |         batch_size=train_params['test_batch'],
124 |         shuffle=False,
125 |         collate_fn=collator
126 |     )
127 | 
128 |     model = SetsAE(device, **train_params).to(device)
129 | 
130 |     mapper = NetworkMapperSetsAE(
131 |         train_params['mapper'], model_path, match_params, device
132 |     )
133 |     # mapper = MapperSetsAE(train_params['matcher'], train_params['p'], device)
134 |     optimiser = torch.optim.Adam(model.parameters(), lr=train_params['lr'])
135 | 
136 |     lr_scheduler = LR_SCHEDULER_FACTORY[train_params['scheduler']
137 |                                         ](optimiser, *train_params['lr_args'])
138 | 
139 |     loss = SetAELoss(train_params['loss'], device)
140 |     min_loss = np.inf
141 |     epochs = train_params['epochs']
142 | 
143 |     plot_train = []
144 |     plot_valid = []
145 |     for epoch in range(epochs):
146 | 
147 |         model.train()
148 |         logger.info("=== Epoch [{}/{}]".format(epoch + 1, epochs))
149 | 
150 |         if epoch == 1:
151 |             tic = time.time()
152 | 
153 |         train_time = 0
154 |         for idx, (x_train, train_lengths) in enumerate(train_loader):
155 | 
156 |             x_train = x_train.to(device)
157 | 
158 |             pred_train, prob_train = model(x_train, x_train.size(1), train_lengths)
159 | 
160 |             t0 = time.time()
161 |             mapped_outputs_train, mapped_prob_train, train12 = mapper(
162 |                 x_train, pred_train, prob_train
163 |             )
164 |             torch.cuda.current_stream().synchronize()
165 |             t1 = time.time()
166 | 
167 |             if epoch == 0 and idx == 0:
168 |                 logger_mem.info(cpuStats())
169 |                 logger_mem.info(get_gpu_memory_map())  # log directly; print() returns None
170 |                 logger_mem.info('x_train:{}'.format(x_train.size()))
171 |                 logger_mem.info('pred_train:{}'.format(pred_train.size()))
172 |                 logger_mem.info('prob_train:{}'.format(prob_train.size()))
173 |                 logger_mem.info('train_lengths:{}'.format(train_lengths.size()))
174 | 
logger_mem.info(torch.cuda.memory_allocated()) 175 | 176 | train_time += t1 - t0 177 | 178 | train_loss = loss( 179 | x_train, mapped_outputs_train, mapped_prob_train, train_lengths 180 | ) 181 | 182 | optimiser.zero_grad() 183 | train_loss.backward() 184 | optimiser.step() 185 | 186 | if epoch == 1: 187 | toc = time.time() 188 | logger_mem.info("Total training time for one epoch = {}".format(toc - tic)) 189 | 190 | logger_mem.info( 191 | "Total mapping time in training epoch {} = {}".format(epoch, train_time) 192 | ) 193 | plot_train.append(train_loss.detach()) 194 | logger.info("Train Loss = {}".format(train_loss.detach())) 195 | 196 | model.eval() 197 | avg_valid_loss = 0 198 | for idx, (x_valid, valid_lengths) in enumerate(valid_loader): 199 | 200 | x_valid = x_valid.to(device) 201 | 202 | pred_valid, prob_valid = model(x_valid, x_valid.size(1), valid_lengths) 203 | 204 | mapped_outputs_valid, mapped_prob_valid, valid12 = mapper( 205 | x_valid, pred_valid, prob_valid 206 | ) 207 | 208 | valid_loss = loss( 209 | x_valid, mapped_outputs_valid, mapped_prob_valid, valid_lengths 210 | ) 211 | 212 | avg_valid_loss = (avg_valid_loss * idx + valid_loss.detach()) / (idx + 1) 213 | 214 | plot_valid.append(avg_valid_loss.detach()) 215 | logger.info("Avg Valid Loss = {}".format(avg_valid_loss.detach())) 216 | if avg_valid_loss < min_loss: 217 | min_loss = avg_valid_loss 218 | torch.save( 219 | { 220 | 'epoch': epoch, 221 | 'model_state_dict': model.state_dict(), 222 | 'optimizer_state_dict': optimiser.state_dict(), 223 | 'train_loss': train_loss, 224 | 'valid_loss': valid_loss 225 | # 'learning_rate': lr_expscheduler.get_last_lr()[0] 226 | }, 227 | save_here 228 | ) 229 | lr_scheduler.step() 230 | 231 | torch.save(plot_train, os.path.join(results_path, 'train_loss')) 232 | torch.save(plot_valid, os.path.join(results_path, 'avg_valid_loss')) 233 | 234 | ax.plot(range(len(plot_train)), plot_train, color=clrs[0], label="Training Loss") 235 | ax.plot(range(len(plot_valid)), plot_valid, color=clrs[1], label="Validation Loss") 236 | ax.legend() 237 | fig.savefig(os.path.join(results_path, 'setsAE.png')) 238 | 239 | avg_test_loss = 0 240 | model.eval() 241 | test_predlist = [] 242 | test_truelist = [] 243 | test12_list = [] 244 | mapping_time = 0 245 | 246 | for idx, (x_test, test_lengths) in enumerate(test_loader): 247 | 248 | x_test = x_test.to(device) 249 | 250 | pred_test, prob_test = model(x_test, x_test.size(1), test_lengths) 251 | 252 | tic = time.time() 253 | mapped_outputs_test, mapped_prob_test, test12 = mapper( 254 | x_test, pred_test, prob_test 255 | ) 256 | toc = time.time() 257 | mapping_time += toc - tic 258 | 259 | test_loss = loss(x_test, mapped_outputs_test, mapped_prob_test, test_lengths) 260 | 261 | test_predlist.append(mapped_outputs_test.detach().cpu().numpy()) 262 | test_truelist.append(x_test.detach().cpu().numpy()) 263 | test12_list.append(test12) 264 | 265 | avg_test_loss = (avg_test_loss * idx + test_loss.detach()) / (idx + 1) 266 | 267 | torch.save(test_predlist, os.path.join(results_path, 'reconstructions')) 268 | torch.save(test_truelist, os.path.join(results_path, 'original')) 269 | torch.save(test12_list, os.path.join(results_path, 'test12')) 270 | 271 | logger.info("Avg Test Loss = {}".format(avg_test_loss.detach())) 272 | logger.info("Total Mapping Time = {} seconds".format(mapping_time)) 273 | 274 | batches = test_dataset.__len__() / train_params['test_batch'] 275 | min_batch_size = (batches - int(batches)) * train_params['test_batch'] 276 | 277 | for i in 
range(train_params['num_plots']): 278 | 279 | batch = torch.randint(0, len(test_loader), (1, )) 280 | 281 | if batch == len(test_loader) - 1: 282 | sample = torch.randint(0, int(min_batch_size), (1, )) 283 | else: 284 | sample = torch.randint(0, train_params['test_batch'], (1, )) 285 | 286 | fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, constrained_layout=True) 287 | 288 | for dim in range(128): 289 | sns.kdeplot(test_truelist[batch][sample][:, dim], ax=ax1) 290 | sns.kdeplot(test_predlist[batch][sample][:, dim], ax=ax2) 291 | 292 | ax1.set( 293 | ylabel='Density (unstandardised)', 294 | title='Original Sample KDE of n=128 Latent Dimensions' 295 | ) 296 | ax2.set( 297 | xlabel='Sample Values', 298 | ylabel='Density (unstandardised)', 299 | title='Reconstructed Sample KDE of n=128 Latent Dimensions' 300 | ) 301 | fig.savefig(os.path.join(results_path, f"Test_{i+1}.png")) 302 | 303 | plt.close() 304 | 305 | 306 | if __name__ == '__main__': 307 | args = parser.parse_args() 308 | main( 309 | args.model_path, args.results_path, args.training_params, args.matching_params, 310 | args.train_data_path, args.valid_data_path, args.test_data_path 311 | ) 312 | -------------------------------------------------------------------------------- /examples/set_autoencoder/reconstruct_128D_cancer_data/train_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "cell": "pLSTM", 3 | "input_size": 128, 4 | "hidden_size_linear": 256, 5 | "hidden_size_encoder": 256, 6 | "hidden_size_decoder": 256, 7 | "mapper": "rnn", 8 | "matcher": "HM", 9 | "p": 2, 10 | "loss": "CrossEntropy", 11 | "epochs": 20, 12 | "lr": 0.0001, 13 | "scheduler": "step", 14 | "lr_args": [ 15 | 10, 16 | 1.0 17 | ], 18 | "train_batch": 128, 19 | "test_batch": 128, 20 | "valid_batch": 128, 21 | "test_size": 50000, 22 | "train_size": 499082, 23 | "valid_size": 50000, 24 | "padding_value": 4.0, 25 | "masking": "False", 26 | "num_plots": 10, 27 | "comment" : "no masking of output, matching done on stacked output directly, and mapped output also uses stacked output" 28 | } 29 | -------------------------------------------------------------------------------- /examples/set_autoencoder/shapes/shapes_train.json: -------------------------------------------------------------------------------- 1 | { 2 | "cell": "pLSTM", 3 | "input_size": 2, 4 | "hidden_size_linear": 66, 5 | "hidden_size_encoder": 66, 6 | "hidden_size_decoder": 66, 7 | "matcher": "HM", 8 | "p": 2, 9 | "loss": "CrossEntropy", 10 | "epochs": 15, 11 | "lr": 0.0068, 12 | "scheduler": "exp", 13 | "lr_args": [ 14 | 0.69 15 | ], 16 | "train_batch": 128, 17 | "test_batch": 128, 18 | "valid_batch": 128, 19 | "masking": "False", 20 | "num_plots": 10 21 | } -------------------------------------------------------------------------------- /examples/set_autoencoder/shapes/shapes_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import sys 6 | 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import seaborn as sns 10 | import torch 11 | from fdsa.datasets.torch_dataset import Collate, ToySetsDataset 12 | from fdsa.models.sets_autoencoder import SetsAE 13 | from fdsa.utils.hyperparameters import LR_SCHEDULER_FACTORY 14 | from fdsa.utils.loss_setae import SetAELoss 15 | from fdsa.utils.mapper import MapperSetsAE 16 | from torch.utils.data import DataLoader 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument( 20 
| 'model_path', type=str, help='Path to save the best performing model.' 21 | ) 22 | parser.add_argument( 23 | 'results_path', type=str, help='Path to save the training and validation losses.' 24 | ) 25 | 26 | parser.add_argument('training_data_path', type=str, help='Path to the training data.') 27 | parser.add_argument( 28 | 'validation_data_path', type=str, help='Path to the validation data.' 29 | ) 30 | parser.add_argument('testing_data_path', type=str, help='Path to the testing data.') 31 | parser.add_argument( 32 | 'training_params', type=str, help='Path to the training parameters json file.' 33 | ) 34 | 35 | # get device 36 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 37 | 38 | 39 | def main( 40 | model_path: str, results_path: str, training_data_path: str, 41 | validation_data_path: str, testing_data_path: str, training_params: str 42 | ): 43 | """Main function for reconstructing shapes. 44 | 45 | Args: 46 | model_path (str): Path to save the best performing model. 47 | results_path (str): Path to save the training and validation losses. 48 | training_data_path (str): Path to the training data. 49 | validation_data_path (str): Path to the validation data. 50 | testing_data_path (str): Path to the testing data. 51 | training_params (str): Path to the training parameters json file. 52 | """ 53 | 54 | # setup logging 55 | logging.basicConfig( 56 | handlers=[ 57 | logging.FileHandler(os.path.join(results_path, "sets_autoencoder.log")), 58 | logging.StreamHandler(sys.stdout) 59 | ], 60 | ) 61 | logger = logging.getLogger('shapes') 62 | logger.setLevel(logging.DEBUG) 63 | 64 | clrs = sns.color_palette('Set2', 2) 65 | fig2, ax2 = plt.subplots(constrained_layout=True) 66 | ax2.set_title('Train and Validation Loss of Sets AE on Shapes Dataset') 67 | ax2.set_xlabel('Epochs') 68 | ax2.set_ylabel('Loss') 69 | 70 | save_here = os.path.join(model_path, 'shapes_setsae') 71 | model = None 72 | if model is not None: 73 | del (model) 74 | 75 | with open(training_params, 'r') as readjson: 76 | params = json.load(readjson) 77 | 78 | max_length = params.get('max_length', 33) 79 | dim = params.get('input_size', 2) 80 | padding_value = params.get('padding_value', 2.0) 81 | 82 | collator = Collate(max_length, dim, padding_value, device) 83 | 84 | # get training data 85 | train_dataset = ToySetsDataset(training_data_path, ['x', 'y'], ['label'], ['ID']) 86 | train_loader = DataLoader( 87 | train_dataset, 88 | batch_size=params['train_batch'], 89 | shuffle=True, 90 | collate_fn=collator 91 | ) 92 | 93 | valid_dataset = ToySetsDataset(validation_data_path, ['x', 'y'], ['label'], ['ID']) 94 | valid_loader = DataLoader( 95 | valid_dataset, 96 | batch_size=params['valid_batch'], 97 | shuffle=True, 98 | collate_fn=collator 99 | ) 100 | 101 | test_dataset = ToySetsDataset(testing_data_path, ['x', 'y'], ['label'], ['ID']) 102 | test_loader = DataLoader( 103 | test_dataset, 104 | batch_size=params['test_batch'], 105 | shuffle=True, 106 | collate_fn=collator 107 | ) 108 | 109 | model = SetsAE(device, **params).to(device) 110 | mapper = MapperSetsAE(params['matcher'], params['p'], device) 111 | loss = SetAELoss(params['loss'], device) 112 | 113 | optimiser = torch.optim.Adam(model.parameters(), lr=params['lr']) 114 | 115 | lr_scheduler = LR_SCHEDULER_FACTORY[params['scheduler'] 116 | ](optimiser, *params['lr_args']) 117 | 118 | min_loss = np.inf 119 | epochs = params['epochs'] 120 | 121 | plot_train = [] 122 | plot_valid = [] 123 | for epoch in range(epochs): 124 | 125 | model.train() 126 | 
logger.info("=== Epoch [{}/{}]".format(epoch + 1, epochs)) 127 | 128 | for idx, (x_batch, batch_lengths) in enumerate(train_loader): 129 | 130 | x_batch = x_batch.to(device) 131 | 132 | pred_train, prob_train = model(x_batch, x_batch.size(1), batch_lengths) 133 | 134 | mapped_outputs, mapped_prob, _ = mapper(x_batch, pred_train, prob_train) 135 | 136 | train_loss = loss(x_batch, mapped_outputs, mapped_prob, batch_lengths) 137 | 138 | optimiser.zero_grad() 139 | train_loss.backward() 140 | optimiser.step() 141 | 142 | cpu_train_loss = train_loss.detach().cpu().numpy() 143 | plot_train.append(cpu_train_loss) 144 | logger.info("Train Loss = {}".format(cpu_train_loss)) 145 | 146 | model.eval() 147 | avg_valid_loss = 0 148 | for idx, (x_valid, batch_lengths) in enumerate(valid_loader): 149 | 150 | x_valid = x_valid.to(device) 151 | 152 | pred_valid, prob_valid = model(x_valid, x_valid.size(1), batch_lengths) 153 | 154 | mapped_outputs_valid, mapped_prob_valid, _ = mapper( 155 | x_valid, pred_valid, prob_valid 156 | ) 157 | 158 | valid_loss = loss( 159 | x_valid, mapped_outputs_valid, mapped_prob_valid, batch_lengths 160 | ) 161 | 162 | avg_valid_loss = (avg_valid_loss * idx + valid_loss.detach()) / (idx + 1) 163 | 164 | cpu_avg_valid_loss = avg_valid_loss.detach().cpu().numpy() 165 | plot_valid.append(cpu_avg_valid_loss) 166 | logger.info("Avg Valid Loss = {}".format(cpu_avg_valid_loss)) 167 | 168 | if avg_valid_loss < min_loss: 169 | min_loss = avg_valid_loss 170 | torch.save( 171 | { 172 | 'epoch': epoch, 173 | 'model_state_dict': model.state_dict(), 174 | 'optimizer_state_dict': optimiser.state_dict(), 175 | 'train_loss': train_loss, 176 | 'valid_loss': valid_loss 177 | }, save_here 178 | ) 179 | if params['scheduler'] == 'plateau': 180 | lr_scheduler.step(valid_loss) 181 | else: 182 | lr_scheduler.step() 183 | 184 | np.save(os.path.join(results_path, 'train_loss'), plot_train) 185 | np.save(os.path.join(results_path, 'avg_valid_loss'), plot_valid) 186 | ax2.plot(range(len(plot_train)), plot_train, color=clrs[0], label="Train Loss") 187 | ax2.plot(range(len(plot_valid)), plot_valid, color=clrs[1], label="Valid Loss") 188 | ax2.legend() 189 | fig2.savefig(os.path.join(results_path, 'setsAE.png')) 190 | 191 | avg_test_loss = 0 192 | model.eval() 193 | test_predlist = [] 194 | test_truelist = [] 195 | for idx, (x_test, batch_lengths) in enumerate(test_loader): 196 | 197 | x_test = x_test.to(device) 198 | 199 | pred_test, prob_test = model(x_test, x_test.size(1), batch_lengths) 200 | 201 | mapped_outputs_test, mapped_prob_test, _ = mapper(x_test, pred_test, prob_test) 202 | 203 | test_loss = loss(x_test, mapped_outputs_test, mapped_prob_test, batch_lengths) 204 | 205 | test_predlist.append(mapped_outputs_test.detach().cpu().numpy()) 206 | test_truelist.append(x_test.detach().cpu().numpy()) 207 | 208 | avg_test_loss = (avg_test_loss * idx + test_loss.detach()) / (idx + 1) 209 | 210 | np.save(os.path.join(results_path, 'reconstructions'), test_predlist) 211 | np.save(os.path.join(results_path, 'original'), test_truelist) 212 | logger.info("Avg Test Loss = {}".format(avg_test_loss.detach().cpu().numpy())) 213 | 214 | batches = test_dataset.__len__() / params['test_batch'] 215 | min_batch_size = (batches - int(batches)) * params['test_batch'] 216 | for i in range(params['num_plots']): 217 | 218 | batch = torch.randint(0, len(test_loader), (1, )) 219 | 220 | if batch == len(test_loader) - 1: 221 | sample = torch.randint(0, int(min_batch_size), (1, )) 222 | else: 223 | sample = torch.randint(0, 
params['test_batch'], (1, )) 224 | 225 | fig, ax = plt.subplots() 226 | ax.set_title("Visualisation of the Original 2D Shape and its Reconstruction") 227 | ax.set_xlim([-0.5, 0.5]) 228 | ax.set_ylim([-0.5, 0.5]) 229 | ax.scatter( 230 | test_truelist[batch][sample][:, 0], 231 | test_truelist[batch][sample][:, 1], 232 | label='Original Sample' 233 | ) 234 | ax.scatter( 235 | test_predlist[batch][sample][:, 0], 236 | test_predlist[batch][sample][:, 1], 237 | label='Reconstructed Sample' 238 | ) 239 | ax.legend() 240 | fig.savefig(os.path.join(results_path, f"Test_{i+1}.png")) 241 | plt.close() 242 | 243 | 244 | if __name__ == '__main__': 245 | args = parser.parse_args() 246 | 247 | main( 248 | args.model_path, args.results_path, args.training_data_path, 249 | args.validation_data_path, args.testing_data_path, args.training_params 250 | ) 251 | -------------------------------------------------------------------------------- /examples/set_matching/conda.yml: -------------------------------------------------------------------------------- 1 | name: molecule_generation 2 | channels: 3 | - https://conda.anaconda.org/rdkit 4 | dependencies: 5 | - rdkit=2019.03.1 6 | - python>=3.6,<3.8 7 | - pip>=19.1 8 | - pip: 9 | - pytoda @ git+https://git@github.com/PaccMann/paccmann_datasets@0.2.4 10 | - fdsa @ git+https://github.com/PaccMann/fdsa@0.0.1 11 | - numpy>=1.14.3 12 | - pandas>=0.24.1 13 | - torch>=1.0.1 14 | - matplotlib>=2.2.2 15 | - seaborn>=0.9.0 16 | -------------------------------------------------------------------------------- /examples/set_matching/rnn/birnn_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_size": 500000, 3 | "valid_size": 50000, 4 | "test_size": 50000, 5 | "input_size": 128, 6 | "seeds": { 7 | "train1": "None", 8 | "train2": "None", 9 | "train_true": "None", 10 | "valid1": "None", 11 | "valid2": "None", 12 | "valid_true": "None", 13 | "test1": "None", 14 | "test2": "None", 15 | "test_true": "None" 16 | }, 17 | "dataset_type": "permuted", 18 | "padding_value": 6.0, 19 | "noise_std": 0.5, 20 | "distribution_type": "normal", 21 | "distribution_args": { 22 | "loc": 0.0, 23 | "scale": 1.0 24 | }, 25 | "cost_metric": "p-norm", 26 | "cost_metric_args": { 27 | "p": 2 28 | }, 29 | "cell": "GRU", 30 | "layers": 1, 31 | "hidden_size": 512, 32 | "bidirectional": "True", 33 | "return_sequences": "True", 34 | "batch_first": "False", 35 | "fc_layers": 2, 36 | "fc_units": [ 37 | 128, 38 | 5 39 | ], 40 | "fc_activation": [ 41 | "lrelu", 42 | "None" 43 | ], 44 | "train_batch": 128, 45 | "valid_batch": 128, 46 | "test_batch": 128, 47 | "epochs": 20, 48 | "lr": 0.001, 49 | "max_length": 5, 50 | "min_length": 2, 51 | "loss": "ce_row", 52 | "ce_type": "row", 53 | "sinkhorn_iter": 1, 54 | "temperature": 1.0 55 | } 56 | -------------------------------------------------------------------------------- /examples/set_matching/rnn/rnn_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_size": 500000, 3 | "valid_size": 50000, 4 | "test_size": 50000, 5 | "input_size": 128, 6 | "seeds": { 7 | "train1": "None", 8 | "train2": "None", 9 | "train_true": "None", 10 | "valid1": "None", 11 | "valid2": "None", 12 | "valid_true": "None", 13 | "test1": "None", 14 | "test2": "None", 15 | "test_true": "None" 16 | }, 17 | "dataset_type": "permuted", 18 | "padding_value": 6.0, 19 | "noise_std": 0.5, 20 | "distribution_type": "normal", 21 | "distribution_args": { 22 | "loc": 0.0, 23 | "scale": 1.0 24 | 
}, 25 | "cost_metric": "p-norm", 26 | "cost_metric_args": { 27 | "p": 2 28 | }, 29 | "cell": "GRU", 30 | "layers": 1, 31 | "hidden_size": 512, 32 | "bidirectional": "False", 33 | "return_sequences": "True", 34 | "batch_first": "False", 35 | "fc_layers": 2, 36 | "fc_units": [ 37 | 128, 38 | 5 39 | ], 40 | "fc_activation": [ 41 | "lrelu", 42 | "None" 43 | ], 44 | "train_batch": 128, 45 | "valid_batch": 128, 46 | "test_batch": 128, 47 | "epochs": 20, 48 | "lr": 0.001, 49 | "max_length": 5, 50 | "min_length": 2, 51 | "loss": "ce_row", 52 | "ce_type": "row", 53 | "sinkhorn_iter": 1, 54 | "temperature": 1.0 55 | } 56 | -------------------------------------------------------------------------------- /examples/set_matching/seq2seq/seq2seq_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_size": 500000, 3 | "valid_size": 50000, 4 | "test_size": 50000, 5 | "seeds": { 6 | "train1": "None", 7 | "train2": "None", 8 | "train_true": "None", 9 | "valid1": "None", 10 | "valid2": "None", 11 | "valid_true": "None", 12 | "test1": "None", 13 | "test2": "None", 14 | "test_true": "None" 15 | }, 16 | "padding_value":6.0, 17 | "dataset_type": "permuted", 18 | "noise_std":0.5, 19 | "distribution_type": "normal", 20 | "distribution_args": { 21 | "loc": 0.0, 22 | "scale": 1.0 23 | }, 24 | "cost_metric": "p-norm", 25 | "cost_metric_args": { 26 | "p": 2 27 | }, 28 | "cell": "GRU", 29 | "layers": 1, 30 | "hidden_size": 512, 31 | "fc_layers": 2, 32 | "fc_units": [ 33 | 128, 34 | 5 35 | ], 36 | "fc_activation": [ 37 | "lrelu", 38 | "None" 39 | ], 40 | "input_size": 128, 41 | "output_size": 5, 42 | "train_batch": 256, 43 | "valid_batch": 128, 44 | "test_batch": 128, 45 | "batch_first": "False", 46 | "bidirectional": "False", 47 | "return_sequences": "True", 48 | "epochs": 20, 49 | "lr": 0.0001, 50 | "max_length": 5, 51 | "min_length": 2, 52 | "loss": "ce_row", 53 | "ce_type": "row", 54 | "sinkhorn_iter": 1, 55 | "temperature": 1.0, 56 | "patience" : 6 57 | } 58 | -------------------------------------------------------------------------------- /examples/set_matching/seq2seq/seq2seq_setmatching.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import sys 6 | import time 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import seaborn as sns 11 | import torch 12 | from fdsa.models.set_matching.seq2seq import Seq2Seq 13 | from fdsa.utils.hyperparameters import LR_SCHEDULER_FACTORY 14 | from fdsa.utils.loss_setmatching import SetMatchLoss 15 | from pytoda.datasets.distributional_dataset import DistributionalDataset 16 | from pytoda.datasets.set_matching_dataset import ( 17 | PairedSetMatchingDataset, 18 | PermutedSetMatchingDataset, 19 | ) 20 | from pytoda.datasets.utils.factories import ( 21 | DISTRIBUTION_FUNCTION_FACTORY, 22 | METRIC_FUNCTION_FACTORY, 23 | ) 24 | from torch.utils.data import DataLoader 25 | 26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("model_path", type=str, help="Path to save best model.") 30 | parser.add_argument( 31 | "results_path", type=str, help="Path to save results and plots." 32 | ) 33 | parser.add_argument( 34 | "training_params", type=str, help="Path to the training parameters." 
35 | ) 36 | 37 | 38 | def main(model_path, results_path, training_params): 39 | 40 | torch.Tensor.ndim = property(lambda self: len(self.shape)) 41 | # setup logging 42 | logging.basicConfig( 43 | handlers=[ 44 | logging.FileHandler( 45 | os.path.join(results_path, "set_matchingSeq2Seq.log") 46 | ), 47 | logging.StreamHandler(sys.stdout), 48 | ], 49 | ) 50 | logger = logging.getLogger("sets") 51 | logger.setLevel(logging.DEBUG) 52 | 53 | # setup plots 54 | clrs = sns.color_palette("Paired", 8) 55 | fig, ax = plt.subplots(constrained_layout=True) 56 | ax.set_title("Train and Validation Loss of Sets Matching") 57 | ax.set_xlabel("Epochs") 58 | ax.set_ylabel("Loss") 59 | 60 | fig2, (ax2, ax3) = plt.subplots(2, 1, sharex=True, constrained_layout=True) 61 | ax2.set_title( 62 | "Train and Validation Accuracy of Various Recurrent Cells on Set Matching Task" 63 | ) 64 | ax3.set(xlabel="Epochs", ylabel="Proportion of Samples with 100% Accuracy") 65 | ax2.set(ylabel="Average per Set Accuracy") 66 | 67 | with open(training_params, "r") as readjson: 68 | params = json.load(readjson) 69 | 70 | dataset_dict = { 71 | "permuted": PermutedSetMatchingDataset, 72 | "sampled": PairedSetMatchingDataset, 73 | } 74 | 75 | train_size = params.get("train_size", 20000) 76 | valid_size = params.get("valid_size", 5000) 77 | test_size = params.get("test_size", 20000) 78 | 79 | input_dim = params.get("input_size", 5) 80 | max_length = params.get("max_length", 5) 81 | min_length = params.get("min_length", 2) 82 | padding_value = params.get("padding_value", 4.0) 83 | noise_std = params.get("noise_std", 0.0) 84 | 85 | dist_type = params.get("distribution_type", "normal") 86 | dist_args = params.get("distribution_args", {"loc": 0.0, "scale": 1.0}) 87 | tensor_args = {} 88 | for key, value in dist_args.items(): 89 | if key == "covariance_matrix": 90 | value = torch.tensor(value, device=device) 91 | identity_matrix = torch.eye(len(value), device=device) 92 | identity_matrix[range(len(value)), range(len(value))] = value 93 | tensor_args[key] = identity_matrix 94 | else: 95 | tensor_args[key] = torch.tensor(value, device=device) 96 | 97 | metric = params.get("cost_metric", "p-norm") 98 | metric_args = params.get("cost_metric_args", {"p": 2}) 99 | 100 | seeds = params["seeds"] 101 | for k, i in seeds.items(): 102 | seeds[k] = eval(i) 103 | 104 | dataset_type = params.get("dataset_type", "permute") 105 | batch_first = eval(params["batch_first"]) 106 | 107 | dist_function = DISTRIBUTION_FUNCTION_FACTORY[dist_type](**tensor_args) 108 | cost_function = METRIC_FUNCTION_FACTORY[metric](**metric_args) 109 | 110 | # setup datasets and dataloader 111 | train, valid, test = [], [], [] 112 | 113 | dataset_train_1 = DistributionalDataset( 114 | train_size, (max_length, input_dim), dist_function, seed=seeds["train1"] 115 | ) 116 | 117 | dataset_valid_1 = DistributionalDataset( 118 | valid_size, (max_length, input_dim), dist_function, seed=seeds["valid1"] 119 | ) 120 | 121 | dataset_test_1 = DistributionalDataset( 122 | test_size, (max_length, input_dim), dist_function, seed=seeds["test1"] 123 | ) 124 | 125 | train.append(dataset_train_1) 126 | valid.append(dataset_valid_1) 127 | test.append(dataset_test_1) 128 | 129 | if dataset_type == "sampled": 130 | dataset_train_2 = DistributionalDataset( 131 | train_size, 132 | (max_length, input_dim), 133 | dist_function, 134 | seed=seeds["train2"], 135 | ) 136 | dataset_valid_2 = DistributionalDataset( 137 | valid_size, 138 | (max_length, input_dim), 139 | dist_function, 140 | seed=seeds["valid2"], 
141 | ) 142 | dataset_test_2 = DistributionalDataset( 143 | test_size, 144 | (max_length, input_dim), 145 | dist_function, 146 | seed=seeds["test2"], 147 | ) 148 | train.append(dataset_train_2) 149 | valid.append(dataset_valid_2) 150 | test.append(dataset_test_2) 151 | 152 | dataset_train = dataset_dict[dataset_type]( 153 | *train, 154 | min_length, 155 | cost_function, 156 | padding_value, 157 | seed=seeds["train_true"], 158 | noise_std=noise_std, 159 | ) 160 | 161 | dataset_test = dataset_dict[dataset_type]( 162 | *test, 163 | min_length, 164 | cost_function, 165 | padding_value, 166 | seed=seeds["test_true"], 167 | noise_std=noise_std, 168 | ) 169 | 170 | dataset_valid = dataset_dict[dataset_type]( 171 | *valid, 172 | min_length, 173 | cost_function, 174 | padding_value, 175 | seed=seeds["valid_true"], 176 | noise_std=noise_std, 177 | ) 178 | 179 | train_loader = DataLoader( 180 | dataset_train, batch_size=params["train_batch"], shuffle=True 181 | ) 182 | 183 | valid_loader = DataLoader( 184 | dataset_valid, batch_size=params["valid_batch"], shuffle=True 185 | ) 186 | 187 | test_loader = DataLoader( 188 | dataset_test, batch_size=params["test_batch"], shuffle=True 189 | ) 190 | 191 | epochs = params["epochs"] 192 | 193 | connector = torch.zeros(1, input_dim).fill_(99.0) 194 | 195 | clr_idx = 0 196 | 197 | patience = params.get("patience", 5) 198 | epochs_no_change = 0 199 | 200 | for cell in ["GRU", "nBRC", "LSTM", "BRC"]: 201 | 202 | params.update({"cell": cell}) 203 | 204 | # setup model params 205 | model = None 206 | if model is not None: 207 | del model 208 | save_here = os.path.join( 209 | model_path, "{}_{}".format(params["loss"], cell) 210 | ) 211 | 212 | # setup model and optimiser 213 | model = Seq2Seq(params, device).to(device) 214 | 215 | optimzr = torch.optim.Adam(model.parameters(), lr=params["lr"]) 216 | # lr_scheduler = LR_SCHEDULER_FACTORY[params['scheduler'] 217 | # ](optimiser, *params['lr_args'],verbose=True) 218 | 219 | min_loss = np.inf 220 | 221 | loss_fn = SetMatchLoss( 222 | params["loss"], 223 | params["ce_type"], 224 | params["temperature"], 225 | params["sinkhorn_iter"], 226 | ) 227 | 228 | plot_train = [] 229 | plot_valid = [] 230 | set_acc_train = [] 231 | set_acc_valid = [] 232 | acc100_train = [] 233 | acc100_valid = [] 234 | 235 | for epoch in range(epochs): 236 | 237 | model.train() 238 | logger.info("=== {} Epoch [{}/{}]".format(cell, epoch + 1, epochs)) 239 | 240 | for idx, ( 241 | x1_train, 242 | x2_train, 243 | targets12_train, 244 | targets21_train, 245 | lens_train, 246 | ) in enumerate(train_loader): 247 | 248 | bs_train = len(targets12_train) 249 | connector_ = connector.expand(bs_train, -1, -1).to(device) 250 | 251 | x1_train, x2_train = x1_train.to(device), x2_train.to(device) 252 | train_x = torch.cat((x1_train, connector_, x2_train), dim=1).to( 253 | device 254 | ) 255 | 256 | targets12_train, targets21_train = ( 257 | targets12_train.to(device), 258 | targets21_train.to(device), 259 | ) 260 | 261 | if batch_first is False: 262 | # permute so batch comes second 263 | train_x = train_x.permute(1, 0, 2) 264 | x1_train = x1_train.permute(1, 0, 2) 265 | 266 | train_output, train_attn = model(train_x, x1_train) 267 | 268 | if batch_first is False: 269 | # permute output so batch comes first for loss calculation 270 | train_output = train_output.permute(1, 0, 2) 271 | 272 | train_loss = loss_fn( 273 | train_output, targets21_train, targets12_train 274 | ) 275 | 276 | optimzr.zero_grad() 277 | train_loss.backward() 278 | optimzr.step() 279 | 280 | 
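                # Training metrics are logged once per epoch, on the penultimate
                # batch: idx == len(train_loader) - 2 is presumably chosen because
                # the final batch may be smaller than train_batch and would bias
                # the per-batch accuracy estimate.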
if idx == (len(train_loader) - 2): 281 | 282 | # Replace 2 with 1 and targets12_train with targets21_train 283 | # for column-wise objective functions such as KL-Div_col. 284 | 285 | train_preds = torch.argmax(train_output.detach(), 2) 286 | all_correct = ( 287 | (train_preds == targets12_train) 288 | .all(dim=1) 289 | .float() 290 | .sum() 291 | ) 292 | correct_train = ( 293 | (train_preds == targets12_train).float().sum() 294 | ) 295 | 296 | plot_train.append(train_loss.detach()) 297 | logger.info("Train Loss = {}".format(train_loss.detach())) 298 | 299 | batch_avg_acc = correct_train / (max_length * bs_train) 300 | set_acc_train.append(batch_avg_acc) 301 | logger.info("Train per set Acc = {}".format(batch_avg_acc)) 302 | 303 | acc_train = all_correct / bs_train 304 | acc100_train.append(acc_train) 305 | logger.info("Train 100% Acc = {}".format(acc_train)) 306 | 307 | model.eval() 308 | avg_valid_loss = 0 309 | all_correct_valid = 0 310 | accuracy_valid = [] 311 | 312 | for idx, ( 313 | x1_valid, 314 | x2_valid, 315 | targets12_valid, 316 | targets21_valid, 317 | lens_valid, 318 | ) in enumerate(valid_loader): 319 | 320 | bs_valid = len(targets12_valid) 321 | connector_ = connector.expand(bs_valid, -1, -1).to(device) 322 | 323 | x1_valid, x2_valid = x1_valid.to(device), x2_valid.to(device) 324 | valid_x = torch.cat((x1_valid, connector_, x2_valid), dim=1).to( 325 | device 326 | ) 327 | 328 | targets12_valid, targets21_valid = ( 329 | targets12_valid.to(device), 330 | targets21_valid.to(device), 331 | ) 332 | 333 | if batch_first is False: 334 | # permute so batch comes second 335 | valid_x = valid_x.permute(1, 0, 2) 336 | x1_valid = x1_valid.permute(1, 0, 2) 337 | 338 | valid_output, valid_attn = model(valid_x, x1_valid) 339 | 340 | if batch_first is False: 341 | # permute output so batch comes first for loss calculation 342 | valid_output = valid_output.permute(1, 0, 2) 343 | 344 | valid_loss = loss_fn( 345 | valid_output, targets21_valid, targets12_valid 346 | ) 347 | 348 | avg_valid_loss = ( 349 | avg_valid_loss * idx + valid_loss.detach() 350 | ) / (idx + 1) 351 | 352 | # Replace 2 with 1 and target12_valid with target21_valid 353 | # for column-wise objective functions such as KL-Div_col. 
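                # Two metrics are tracked: per-set accuracy counts individual
                # positions predicted correctly, while the "100%" accuracy via
                # .all(dim=1) counts only sets whose every position is correct.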
354 | 355 | valid_preds = torch.argmax(valid_output.detach(), 2) 356 | all_correct_valid += ( 357 | (valid_preds == targets12_valid).all(dim=1).float().sum() 358 | ) 359 | correct_valid = (valid_preds == targets12_valid).float().sum() 360 | 361 | accuracy_valid.append(correct_valid / (max_length * bs_valid)) 362 | 363 | plot_valid.append(avg_valid_loss.detach()) 364 | logger.info("Avg Valid Loss = {}".format(avg_valid_loss.detach())) 365 | 366 | avg_valid_acc = torch.mean(torch.as_tensor(accuracy_valid)) 367 | set_acc_valid.append(avg_valid_acc) 368 | logger.info("Valid per set Acc = {}".format(avg_valid_acc)) 369 | 370 | acc_valid = all_correct_valid / valid_size 371 | acc100_valid.append(acc_valid) 372 | logger.info("Valid 100% Acc = {}".format(acc_valid)) 373 | 374 | if avg_valid_loss < min_loss: 375 | min_loss = avg_valid_loss 376 | epochs_no_change = 0 377 | torch.save( 378 | { 379 | "epoch": epoch, 380 | "model_state_dict": model.state_dict(), 381 | "optimizer_state_dict": optimzr.state_dict(), 382 | "train_loss": train_loss, 383 | "valid_loss": valid_loss, 384 | }, 385 | save_here, 386 | ) 387 | else: 388 | if epoch >= 19: 389 | epochs_no_change += 1 390 | 391 | if epochs_no_change == patience: 392 | logger.info("Early Stopping at {} epochs".format(epoch)) 393 | break 394 | 395 | # lr_scheduler.step(avg_valid_loss.detach()) 396 | 397 | torch.save( 398 | plot_train, os.path.join(results_path, "train_loss_{}".format(cell)) 399 | ) 400 | torch.save( 401 | plot_valid, 402 | os.path.join(results_path, "avg_valid_loss_{}".format(cell)), 403 | ) 404 | ax.plot( 405 | range(len(plot_train)), 406 | plot_train, 407 | color=clrs[clr_idx * 2], 408 | label="{} Training Loss".format(cell), 409 | ) 410 | ax.plot( 411 | range(len(plot_valid)), 412 | plot_valid, 413 | color=clrs[clr_idx * 2 + 1], 414 | label="{} Validation Loss".format(cell), 415 | ) 416 | 417 | ax2.plot( 418 | range(len(set_acc_train)), 419 | set_acc_train, 420 | color=clrs[clr_idx * 2], 421 | label="{} Batch Training Accuracy".format(cell), 422 | ) 423 | ax2.plot( 424 | range(len(set_acc_valid)), 425 | set_acc_valid, 426 | color=clrs[clr_idx * 2 + 1], 427 | label="{} Validation Accuracy".format(cell), 428 | ) 429 | 430 | ax3.plot( 431 | range(len(acc100_train)), 432 | acc100_train, 433 | color=clrs[clr_idx * 2], 434 | label="{} Training Accuracy".format(cell), 435 | ) 436 | ax3.plot( 437 | range(len(acc100_valid)), 438 | acc100_valid, 439 | color=clrs[clr_idx * 2 + 1], 440 | label="{} Validation Accuracy".format(cell), 441 | ) 442 | 443 | avg_test_loss = 0 444 | all_correct_test = 0 445 | model.eval() 446 | 447 | test_predlist = [] 448 | test12_truelist = [] 449 | test21_truelist = [] 450 | accuracy_test = [] 451 | test_lengths = [] 452 | 453 | for idx, ( 454 | x1_test, 455 | x2_test, 456 | targets12_test, 457 | targets21_test, 458 | lens_test, 459 | ) in enumerate(test_loader): 460 | 461 | bs_test = len(targets12_test) 462 | connector_ = connector.expand(bs_test, -1, -1).to(device) 463 | x1_test, x2_test = x1_test.to(device), x2_test.to(device) 464 | test_x = torch.cat((x1_test, connector_, x2_test), dim=1).to(device) 465 | 466 | targets12_test, targets21_test = ( 467 | targets12_test.to(device), 468 | targets21_test.to(device), 469 | ) 470 | 471 | if batch_first is False: 472 | test_x = test_x.permute(1, 0, 2) 473 | x1_test = x1_test.permute(1, 0, 2) 474 | 475 | test_output, test_attn = model(test_x, x1_test) 476 | 477 | if batch_first is False: 478 | # permute output so batch comes first for loss calculation 479 | test_output = 
test_output.permute(1, 0, 2) 480 | 481 | test_loss = loss_fn(test_output, targets21_test, targets12_test) 482 | 483 | test_predlist.append(test_output.detach()) 484 | test21_truelist.append(targets21_test.detach()) 485 | test12_truelist.append(targets12_test.detach()) 486 | test_lengths.append(lens_test) 487 | 488 | # Replace 2 with 1 and target12_test with target21_test 489 | # for column-wise objective functions such as KL-Div_col. 490 | 491 | test_preds = torch.argmax(test_output, 2) 492 | all_correct_test += ( 493 | (test_preds == targets12_test).all(dim=1).float().sum() 494 | ) 495 | correct_test = (test_preds == targets12_test).float().sum() 496 | 497 | accuracy_test.append(correct_test / (max_length * bs_test)) 498 | 499 | avg_test_loss = (avg_test_loss * idx + test_loss.detach()) / ( 500 | idx + 1 501 | ) 502 | logger.info("Avg Test Loss = {}".format(avg_test_loss.detach())) 503 | 504 | logger.info( 505 | "Avg Test per set Acc = {}".format( 506 | torch.mean(torch.as_tensor(accuracy_test)) 507 | ) 508 | ) 509 | logger.info("Avg Test Acc = {}".format(all_correct_test / test_size)) 510 | 511 | torch.save( 512 | test12_truelist, 513 | os.path.join(results_path, "true_match12_{}".format(cell)), 514 | ) 515 | torch.save( 516 | test21_truelist, 517 | os.path.join(results_path, "true_match21_{}".format(cell)), 518 | ) 519 | torch.save( 520 | test_predlist, 521 | os.path.join(results_path, "pred_match_{}".format(cell)), 522 | ) 523 | torch.save( 524 | test_lengths, 525 | os.path.join(results_path, "test_lengths_{}".format(cell)), 526 | ) 527 | clr_idx += 1 528 | 529 | ax.legend() 530 | fig.savefig( 531 | os.path.join(results_path, "setmatchloss_{}.png".format(params["loss"])) 532 | ) 533 | 534 | ax2.legend(loc="upper right", bbox_to_anchor=(1.2, 1)) 535 | fig2.savefig( 536 | os.path.join(results_path, "setmatchacc_{}.png".format(params["loss"])) 537 | ) 538 | 539 | 540 | if __name__ == "__main__": 541 | args = parser.parse_args() 542 | main(args.model_path, args.results_path, args.training_params) 543 | -------------------------------------------------------------------------------- /fdsa/__init__.py: -------------------------------------------------------------------------------- 1 | name = 'fdsa' 2 | __version__ = '0.0.1' 3 | -------------------------------------------------------------------------------- /fdsa/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaccMann/fdsa/0e592e5281df69b3da2ea004ad3d96d9ca286f4d/fdsa/datasets/__init__.py -------------------------------------------------------------------------------- /fdsa/datasets/galaxy_data.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pandas as pd 4 | from astropy.table import Table 5 | 6 | 7 | class Galaxy: 8 | """Generates a dataset of galaxy clusters. 9 | """ 10 | 11 | def __init__(self, fits_filepath): 12 | """Constructor. 13 | 14 | Args: 15 | fits_filepath (string): Path to the fits file. 16 | """ 17 | self.path = fits_filepath 18 | 19 | @staticmethod 20 | def save_csv(data, filepath): 21 | """Save dataframe as a csv file. 22 | 23 | Args: 24 | data (dataframe): Dataframe to be saved as csv. 25 | filepath (string): Path where the csv file is to be saved. 26 | """ 27 | 28 | assert type(data) == pd.DataFrame 29 | 30 | data.to_csv(filepath) 31 | 32 | def data_galaxy(self) -> Tuple: 33 | """Extracts data from fits file. 
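        Note: despite the Returns description below, the current implementation
        reads the FITS table with astropy and returns the full member catalogue
        as a single pandas DataFrame (via Table.to_pandas()).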
34 | 35 | Returns: 36 | list_of_dataframes (List): Each element in the list is a N(c) x 19 37 | matrix corresponding to a cluster c that describes the galaxies 38 | in that cluster. N(c) is the number of galaxies in cluster c. 39 | The 19 columns represent 2 IDs and 17 features. 40 | list_of_targets (List): Each element in the list is a N(c) x 3 41 | matrix corresponding to a cluster c. The 3 columns represent 42 | cluster ID, object ID, and the true spectrometric value. 43 | """ 44 | 45 | data = Table.read(self.path, format='fits') 46 | dataframe = data.to_pandas() 47 | 48 | return dataframe 49 | -------------------------------------------------------------------------------- /fdsa/datasets/shapes_data.py: -------------------------------------------------------------------------------- 1 | from math import cos, pi, sin 2 | from typing import List, Tuple 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from skimage import draw 7 | 8 | 9 | class Shapes: 10 | """Generates a dataset of shapes with varying sizes and rotations. 11 | """ 12 | 13 | def __init__( 14 | self, 15 | min_x: float = -0.5, 16 | max_x: float = 0.5, 17 | min_y: float = -0.5, 18 | max_y: float = 0.5, 19 | min_boundary: int = 1000, 20 | max_boundary: int = 9000, 21 | img_size: int = 10000, 22 | seed: int = None 23 | ): 24 | """Constructor. 25 | 26 | Args: 27 | min_x (float, optional): Lower boundary of the x axis. 28 | Defaults to -0.5. 29 | max_x (float, optional): Upper boundary of the x axis. 30 | Defaults to 0.5. 31 | min_y (float, optional): Lower boundary of the x axis. 32 | Defaults to -0.5. 33 | max_y (float, optional): Upper boundary of the y axis. 34 | Defaults to 0.5. 35 | min_boundary (int, optional): Lower boundary of the sampling area 36 | for x and y axes. Defaults to 1000. 37 | max_boundary (int, optional): Upper boundary of the sampling area 38 | for x and y axes. Defaults to 9000. 39 | img_size (int, optional): Image resolution. Defaults to 10000. 40 | 41 | NOTE: The min and max boundaries of x and y axes represent the 42 | region where the shapes are required to be situated. 43 | The sampling area boundaries define regions from which the 44 | centres of the shapes are sampled to ensure the above min/max 45 | boundaries are respected. 46 | The shape is first generated as an image; Image resolution 47 | controls how accurately the points are translated to the 48 | required region. 49 | """ 50 | 51 | self.max_x = max_x 52 | self.min_x = min_x 53 | self.max_y = max_y 54 | self.min_y = min_y 55 | self.min_boundary = min_boundary 56 | self.max_boundary = max_boundary 57 | self.img_size = img_size 58 | if seed: 59 | np.random.seed(42) 60 | 61 | @staticmethod 62 | def save_csv(data: pd.DataFrame, filepath: str) -> None: 63 | """Save dataframe as a csv file. 64 | 65 | Args: 66 | data (dataframe): Dataframe to be saved as csv. 67 | filepath (string): Path where the csv file is to be saved. 68 | """ 69 | 70 | assert type(data) == pd.DataFrame 71 | 72 | data.to_csv(filepath) 73 | 74 | def generate_random(self, dim: int, sample_id: int) -> np.ndarray: 75 | """Generates a set of data points that are sampled uniformly at random. 76 | Args: 77 | dim (int): Dimension of the data. 78 | sample_id(int): ID of the set. 
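        NOTE: The set length k is drawn uniformly from 1 to 16 inclusive,
        since np.random.randint(1, 17) excludes the upper bound.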
79 | Returns: 80 | np.ndarray(float64) : dim-D array of the coordinates 81 | """ 82 | k = np.random.randint(1, 17) 83 | element_size = tuple([k, dim]) 84 | column_names = [['x'], ['x', 'y'], ['x', 'y', 'z']] 85 | 86 | label_id = pd.DataFrame({'label': [dim] * k, 'ID': [sample_id] * k}) 87 | 88 | return pd.concat( 89 | [ 90 | pd.DataFrame( 91 | np.random.uniform(self.min_x, self.max_x, size=element_size), 92 | columns=column_names[dim - 1] 93 | ), label_id 94 | ], 95 | axis=1 96 | ) 97 | 98 | def data_random(self, dimension: int, sample_size: int) -> np.ndarray: 99 | """Generates a random uniformly distributed dataset. 100 | 101 | Args: 102 | dimension (int): dimension of elements of the dataset. 103 | sample_size (int): Length of the dataset. 104 | 105 | Returns: 106 | dataset (np.ndarray(float64)): array of dimension-D sets with 107 | varying lengths. 108 | """ 109 | dataset = pd.concat( 110 | [self.generate_random(dimension, i) for i in range(1, sample_size + 1)], 111 | sort=False 112 | ) 113 | 114 | return dataset 115 | 116 | def get_max_radius(self, x: int, y: int, min_radius: int) -> int: 117 | """Generates a radius for the bounding circle based on maximum and 118 | minimum possible radii. 119 | 120 | Args: 121 | x (int): x coordinate of the centre pixel of the circle. 122 | y (int): y coordinate of the centre pixel of the circle. 123 | min_radius (int): Fixed minimum radius of the 124 | bounding circle in pixels. 125 | 126 | Returns: 127 | radius (int): The radius chosen at random from a calculated range. 128 | """ 129 | 130 | r_max_x = min(np.sqrt((0 - x)**2), np.sqrt((self.img_size - x)**2)) 131 | r_max_y = min(np.sqrt((0 - y)**2), np.sqrt((self.img_size - y)**2)) 132 | 133 | r_lim = min(r_max_x, r_max_y) 134 | 135 | radius = np.random.randint(min_radius, int(r_lim)) 136 | 137 | return radius 138 | 139 | def scale(self, value, old_min, old_max, new_min, new_max) -> np.float64: 140 | """Translates a given coordinate from one scale range to another. 141 | 142 | Args: 143 | value (int or float): The 1-D coordinate to be translated. 144 | old_min (int or float): The minimum value of the old range. 145 | old_max (int or float): The maximum value of the old range. 146 | new_min (int or float): The minimum value of the new range. 147 | new_max (int or float): The maximum value of the new range. 148 | 149 | Returns: 150 | (int or float): Value of the coordinate in the new range. 151 | """ 152 | return ((value - old_min) * 153 | (new_max - new_min)) / ((old_max - old_min)) + new_min 154 | 155 | def datapoints_circle( 156 | self, set_length: int, sample_id: int, min_radius: int = 100, use: str = None 157 | ): 158 | """Generates a set of datapoints sampled from the circumference of a 159 | circle. 160 | 161 | Args: 162 | set_length (int): The number of elements required in the set. 163 | sample_id(int): ID of the set. 164 | min_radius (int, optional): The minimum radius of the 165 | bounding circle in pixels. Defaults to 100. 166 | use (string, optional): "square" or "cross" signifies the shape it 167 | is used for as a bounding circle. Defaults to None. 168 | 169 | Returns: 170 | datapoints (np.ndarray): Sampled 2D points from the circumference 171 | of the circle of dtype int if used as a bounding circle, 172 | float otherwise. 173 | radius (int): Radius of the circle. 174 | x (int): The x coordinate of the centre pixel of the circle. 175 | Returned only when use = 'cross'. 176 | y (int): The y coordinate of the centre pixel of the circle. 177 | Returned only when use = 'cross'. 
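        NOTE: The actual return value depends on `use`: 'square' yields
        (x, y, radius), 'cross' yields (datapoints, x, y), and the default
        yields only the sampled DataFrame.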
178 | """ 179 | 180 | x = np.random.randint(self.min_boundary, self.max_boundary) 181 | y = np.random.randint(self.min_boundary, self.max_boundary) 182 | 183 | radius = self.get_max_radius(x, y, min_radius) 184 | 185 | rr, cc = draw.circle_perimeter( 186 | x, y, radius, shape=(self.img_size, self.img_size) 187 | ) 188 | 189 | if use == 'square': 190 | return x, y, radius 191 | 192 | elif use == 'cross': 193 | 194 | # uncomment the following lines to add rotations 195 | 196 | # points = np.array(list(zip(rr, cc))) 197 | 198 | # sampling_indices = np.random.randint(0, len(points), set_length) 199 | # datapoints = points[sampling_indices] 200 | 201 | # comment the following lines to add rotations 202 | 203 | rr0 = int(x + radius * cos(pi / 2)) 204 | cc0 = int(y + radius * sin(pi / 2)) 205 | datapoints = np.array(list(zip([rr0], [cc0]))) 206 | 207 | return datapoints, x, y 208 | 209 | else: 210 | scaled_rr = self.scale(rr, 0, self.img_size, self.min_x, self.max_x) 211 | scaled_cc = self.scale(cc, 0, self.img_size, self.min_y, self.max_y) 212 | 213 | assert all(np.abs(scaled_rr) <= self.max_x) 214 | assert all(np.abs(scaled_cc) <= self.max_y) 215 | 216 | points = pd.DataFrame( 217 | { 218 | 'x': scaled_rr, 219 | 'y': scaled_cc, 220 | 'label': [0] * len(rr), 221 | 'ID': [sample_id] * len(rr) 222 | } 223 | ) 224 | 225 | sampling_indices = np.random.randint(0, len(points), set_length) 226 | datapoints = points.iloc[sampling_indices, :].reset_index(drop=True) 227 | 228 | return datapoints 229 | 230 | def datapoints_square( 231 | self, set_length: int, sample_id: int, min_radius: int = 100 232 | ) -> Tuple: 233 | """Generates a set of datapoints sampled from the perimeter of a 234 | square. 235 | 236 | Args: 237 | set_length (int): The number of elements required in the set. 238 | sample_id(int): ID of the set. 239 | min_radius (int, optional): Fixed minimum radius of the bounding 240 | circle in pixels. Defaults to 100. 241 | 242 | Returns: 243 | datapoints (np.ndarray(float)): Sampled 2D points from the 244 | perimeter of the square. 245 | side (float): Length of the square generated. 
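        NOTE: The current implementation returns only the sampled DataFrame;
        the side length is used internally to place the perimeter.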
246 | """ 247 | 248 | cx, cy, radius = self.datapoints_circle(1, sample_id, use='square') 249 | 250 | side = np.sqrt(radius * radius * 2) 251 | half_side = side * 0.5 252 | 253 | # uncomment the following line to add rotations 254 | 255 | # angle = np.random.random_sample() * pi / 2 256 | 257 | r, c = draw.rectangle_perimeter( 258 | (cx - half_side, cy - half_side), 259 | extent=(side, side), 260 | shape=(self.img_size, self.img_size) 261 | ) 262 | 263 | scaled_r = self.scale(r, 0, self.img_size, self.min_x, self.max_x) 264 | scaled_c = self.scale(c, 0, self.img_size, self.min_y, self.max_y) 265 | 266 | # uncomment the following lines to add rotations 267 | 268 | # mean_r = (max(scaled_r) + min(scaled_r)) / 2 269 | # mean_c = (max(scaled_c) + min(scaled_c)) / 2 270 | 271 | # rot_r = (scaled_r - mean_r) * cos(angle) - (scaled_c - mean_c 272 | # ) * sin(angle) + mean_r 273 | # rot_c = (scaled_r - mean_r) * sin(angle) + (scaled_c - mean_c 274 | # ) * cos(angle) + mean_c 275 | 276 | # assert all(np.abs(rot_r) <= self.max_x) 277 | # assert all(np.abs(rot_c) <= self.max_y) 278 | 279 | # points = pd.DataFrame( 280 | # { 281 | # 'x': rot_r, 282 | # 'y': rot_c, 283 | # 'label': [1] * len(rot_r), 284 | # 'ID': [sample_id] * len(rot_r) 285 | # } 286 | # ) 287 | 288 | # comment the following lines to add rotations 289 | 290 | points = pd.DataFrame( 291 | { 292 | 'x': scaled_r, 293 | 'y': scaled_c, 294 | 'label': [1] * len(scaled_r), 295 | 'ID': [sample_id] * len(scaled_r) 296 | } 297 | ) 298 | 299 | sampling_indices = np.random.randint(0, len(points), set_length) 300 | datapoints = points.iloc[sampling_indices, :].reset_index(drop=True) 301 | 302 | return datapoints 303 | 304 | def datapoints_cross( 305 | self, set_length: int, sample_id: int, min_radius: int = 100 306 | ) -> Tuple: 307 | """Generates a set of datapoints sampled from a cross. 308 | 309 | Args: 310 | set_length (int): The number of elements required in the set. 311 | sample_id(int): ID of the set. 312 | min_radius (int, optional): Fixed minimum radius of the bounding 313 | circle in pixels. Defaults to 100. 314 | 315 | Returns: 316 | datapoints (np.ndarray(float)): Sampled 2D points from the cross. 317 | radius (int): Radius of the bounding circle in pixels. 
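        NOTE: The current implementation returns only the sampled DataFrame;
        the bounding-circle radius is not returned.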
318 | """ 319 | 320 | start_point, centre_x, centre_y = self.datapoints_circle( 321 | 1, sample_id, use='cross' 322 | ) 323 | 324 | x1 = 2 * centre_x - start_point[0, 0] 325 | y1 = 2 * centre_y - start_point[0, 1] 326 | 327 | assert x1.dtype == int 328 | assert y1.dtype == int 329 | 330 | rr, cc = draw.line(start_point[0, 0], start_point[0, 1], x1, y1) 331 | 332 | scaled_rr = self.scale(rr, 0, self.img_size, self.min_x, self.max_x) 333 | scaled_cc = self.scale(cc, 0, self.img_size, self.min_y, self.max_y) 334 | 335 | assert all(np.abs(scaled_rr) <= self.max_x) 336 | assert all(np.abs(scaled_cc) <= self.max_y) 337 | 338 | mean_rr = (max(scaled_rr) + min(scaled_rr)) / 2 339 | mean_cc = (max(scaled_cc) + min(scaled_cc)) / 2 340 | 341 | rot_rr = ( 342 | (scaled_rr - mean_rr) * cos(pi / 2) - (scaled_cc - mean_cc) * sin(pi / 2) 343 | ) + mean_rr 344 | rot_cc = ( 345 | (scaled_rr - mean_rr) * sin(pi / 2) + (scaled_cc - mean_cc) * cos(pi / 2) 346 | ) + mean_cc 347 | 348 | # uncomment the following lines to add rotations 349 | 350 | # assert all(np.abs(rot_rr) <= self.max_x) 351 | # assert all(np.abs(rot_cc) <= self.max_y) 352 | 353 | # points = pd.DataFrame( 354 | # { 355 | # 'x': np.append(scaled_rr, rot_rr), 356 | # 'y': np.append(scaled_cc, rot_cc), 357 | # 'label': [2] * (2 * len(rr)), 358 | # 'ID': [sample_id] * (2 * len(rr)) 359 | # } 360 | # ) 361 | 362 | # comment the following lines to add rotations 363 | 364 | points = pd.DataFrame( 365 | { 366 | 'x': np.append(scaled_rr, rot_rr), 367 | 'y': np.append(scaled_cc, rot_cc), 368 | 'label': [2] * (2 * len(rr)), 369 | 'ID': [sample_id] * (2 * len(rr)) 370 | } 371 | ) 372 | 373 | sampling_indices = np.random.randint(0, len(points), set_length) 374 | datapoints = points.iloc[sampling_indices, :].reset_index(drop=True) 375 | 376 | return datapoints 377 | 378 | def generate_shapes(self, shapes_list: List, sample_id: int) -> np.array: 379 | """Generates a shape at random. 380 | 381 | Args: 382 | shapes_list (list[objects]): List of function names for generating 383 | different shapes. 384 | sample_id(int): ID of the set. 385 | 386 | Returns: 387 | np.ndarray: 2D array of sampled datapoints from the randomly 388 | chosen shape. 389 | """ 390 | set_length = np.random.randint(10, 34) 391 | 392 | return np.random.choice(shapes_list)(set_length, sample_id) 393 | 394 | def data_shapes(self, sample_size: int) -> List: 395 | """Generates a dataset of randomly chosen shapes of varying sizes and 396 | orientations. 397 | 398 | Args: 399 | sample_size (int): Length of the dataset. 400 | 401 | Returns: 402 | dataset (list[float]): Each item is a shape and its coordinates. 403 | 404 | NOTE: dataset is an array of 2D variable length arrays. 
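        Each set is a DataFrame with columns x, y, label and ID, where label
        encodes the shape: 0 = circle, 1 = square, 2 = cross.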
405 | """ 406 | 407 | shapes_list = [ 408 | self.datapoints_circle, self.datapoints_square, self.datapoints_cross 409 | ] 410 | 411 | dataset = pd.concat( 412 | [self.generate_shapes(shapes_list, i) for i in range(1, sample_size + 1)], 413 | sort=False 414 | ) 415 | 416 | return dataset 417 | -------------------------------------------------------------------------------- /fdsa/datasets/tests/test_galaxy_data.py: -------------------------------------------------------------------------------- 1 | """Testing Galaxy Data""" 2 | import os 3 | import shutil 4 | import tempfile 5 | 6 | import pandas as pd 7 | import requests 8 | from fdsa.datasets.galaxy_data import Galaxy 9 | 10 | 11 | def test_data_galaxy(): 12 | """Test data_galaxy.""" 13 | 14 | directory = tempfile.mkdtemp() 15 | 16 | url = ( 17 | 'http://risa.stanford.edu/redmapper/v6.3/' + 18 | 'redmapper_dr8_public_v6.3_members.fits.gz' 19 | ) 20 | 21 | filename = url.split("/")[-1] 22 | filepath = os.path.join(directory, filename) 23 | with open(filepath, "wb") as f: 24 | r = requests.get(url) 25 | f.write(r.content) 26 | 27 | data = Galaxy(filepath).data_galaxy() 28 | 29 | assert type(data) == pd.DataFrame 30 | assert len(pd.unique(data['ID'])) == 26111 31 | assert data.shape[1] == 22 32 | 33 | shutil.rmtree(directory) 34 | -------------------------------------------------------------------------------- /fdsa/datasets/tests/test_shapes_data.py: -------------------------------------------------------------------------------- 1 | """Testing Shapes""" 2 | import numpy as np 3 | import pandas as pd 4 | from fdsa.datasets.shapes_data import Shapes 5 | 6 | 7 | def test_data_random(): 8 | """Test data_random. """ 9 | sample_size = 10 10 | dim = 3 11 | column_names = [['x'], ['x', 'y'], ['x', 'y', 'z']] 12 | data = Shapes().data_random(dim, sample_size) 13 | 14 | assert type(data) == pd.DataFrame 15 | assert len(pd.unique(data['ID'])) == sample_size 16 | assert data.shape[1] == (dim + 2) 17 | assert data['label'].dtype == int 18 | assert data['ID'].dtype == int 19 | 20 | for i in column_names[dim - 1]: 21 | assert data[i].dtype == np.float64 22 | 23 | 24 | def test_datapoints_circle(): 25 | """Test datapoints_circle.""" 26 | sample_size = 10 27 | sample_id = 1 28 | 29 | data = Shapes().datapoints_circle(sample_size, sample_id) 30 | 31 | assert type(data) == pd.DataFrame 32 | assert len(data) == sample_size 33 | assert data.shape[1] == 4 34 | assert data['x'].dtype == np.float64 35 | assert data['y'].dtype == np.float64 36 | assert data['label'].dtype == int 37 | assert data['ID'].dtype == int 38 | 39 | cx, cy, radius = Shapes().datapoints_circle(1, 1, use='square') 40 | 41 | assert type(cx) == int 42 | assert type(cy) == int 43 | assert type(radius) == int 44 | 45 | cross_start, centre_x, centre_y = Shapes().datapoints_circle(1, 1, use='cross') 46 | 47 | assert len(cross_start) == 1 48 | assert cross_start.dtype == int 49 | assert type(centre_x) == int 50 | assert type(centre_y) == int 51 | 52 | 53 | def test_datapoints_square(): 54 | """Test datapoints_square.""" 55 | sample_size = 10 56 | sample_id = 1 57 | 58 | data = Shapes().datapoints_square(sample_size, sample_id) 59 | 60 | assert type(data) == pd.DataFrame 61 | assert len(data) == sample_size 62 | assert data.shape[1] == 4 63 | assert data['x'].dtype == np.float64 64 | assert data['y'].dtype == np.float64 65 | assert data['label'].dtype == int 66 | assert data['ID'].dtype == int 67 | 68 | 69 | def test_datapoints_cross(): 70 | """Test datapoints_cross.""" 71 | sample_size = 10 72 | 
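    # As with the circle and square tests above, the cross generator should
    # return a 4-column DataFrame (x, y, label, ID) with one row per element.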
sample_id = 1 73 | 74 | data = Shapes().datapoints_cross(sample_size, sample_id) 75 | 76 | assert type(data) == pd.DataFrame 77 | assert len(data) == sample_size 78 | assert data.shape[1] == 4 79 | assert data['x'].dtype == np.float64 80 | assert data['y'].dtype == np.float64 81 | assert data['label'].dtype == int 82 | assert data['ID'].dtype == int 83 | 84 | 85 | def test_data_shapes(): 86 | """Test data_shapes.""" 87 | sample_size = 200 88 | data = Shapes().data_shapes(sample_size) 89 | 90 | assert type(data) == pd.DataFrame 91 | assert len(pd.unique(data['ID'])) == sample_size 92 | assert data.shape[1] == 4 93 | assert data['x'].dtype == np.float64 94 | assert data['y'].dtype == np.float64 95 | assert data['label'].dtype == int 96 | assert data['ID'].dtype == int 97 | -------------------------------------------------------------------------------- /fdsa/datasets/torch_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class ToySetsDataset(Dataset): 10 | """Dataset class for loading the Shapes and Galaxy data.""" 11 | 12 | def __init__( 13 | self, 14 | csv_file: str, 15 | features: List[str], 16 | target: List[str], 17 | identifiers: List[str] = ['ID', 'OBJID'] 18 | ) -> None: 19 | """Constructor. 20 | 21 | Args: 22 | csv_file (string): Path to the csv file containing data. 23 | features (list(string)): Column names of features. 24 | target (list(string)): Column name(s) of target. 25 | identifiers (list(string), optional): Column names associated 26 | with ID. Defaults to ['ID','OBJID'] (according to galaxy data). 27 | 28 | """ 29 | self.data = pd.read_csv(csv_file, dtype=np.float32) 30 | 31 | self.identifiers = self.data[identifiers] 32 | self.x = self.data[features] 33 | self.y = self.data[target] 34 | self.unique_id = pd.unique(self.identifiers['ID']) 35 | 36 | def __len__(self) -> int: 37 | """Length of the data. 38 | 39 | Returns: 40 | int: Length of the data. 41 | """ 42 | return len(self.unique_id) 43 | 44 | def __getitem__(self, idx: torch.Tensor) -> Tuple: 45 | """Returns all elements belonging to one set. 46 | 47 | Args: 48 | idx (tensor): Index of set to be sampled. 49 | 50 | Returns: 51 | Tuple: Tuple of tensor of features and corresponding targets subset 52 | by set ID and not element ID, and length of feature tensor. 53 | """ 54 | if torch.is_tensor(idx): 55 | idx = idx.tolist() 56 | 57 | id_ = self.unique_id[idx] 58 | 59 | x = torch.from_numpy(self.x[self.identifiers['ID'].isin([id_])].values) 60 | # y = torch.from_numpy(self.y[self.identifiers['ID'].isin([id_])].values) 61 | 62 | return x 63 | 64 | 65 | class SetsDataset(Dataset): 66 | """Dataset class to load set data from a file.""" 67 | 68 | def __init__( 69 | self, 70 | dataset_path: str, 71 | device: torch.device = torch. 
72 | device('cuda' if torch.cuda.is_available() else 'cpu') 73 | ) -> None: 74 | self.dataset = torch.load(dataset_path).to(device) 75 | 76 | def __len__(self): 77 | """Get length of the dataset.""" 78 | return len(self.dataset) 79 | 80 | def __getitem__(self, index): 81 | """Gets item from the dataset at the given index.""" 82 | 83 | if torch.is_tensor(index): 84 | index = index.tolist() 85 | 86 | return self.dataset[index, :, :] 87 | 88 | 89 | class Collate: 90 | """Class to pad data based on maximum set length in a batch.""" 91 | 92 | def __init__( 93 | self, 94 | max_length: int, 95 | input_dim: int, 96 | padding_value: int, 97 | device: torch.device = torch. 98 | device('cuda' if torch.cuda.is_available() else 'cpu') 99 | ) -> None: 100 | """Constructor. 101 | 102 | Args: 103 | max_length (int): Maximum length of the set as required. 104 | input_dim (int): Size of the input elements in the set. 105 | padding_value (int): Numerical value to pad the set. 106 | device (torch.device, optional): Device on which the data is stored. 107 | Defaults to CPU. 108 | """ 109 | 110 | self.max_length = max_length 111 | self.dim = input_dim 112 | self.pad_val = padding_value 113 | self.device = device 114 | 115 | def __call__(self, batch) -> Tuple: 116 | """Padding function that returns the padded sets, and true lengths of each set 117 | in the batch. 118 | 119 | Args: 120 | batch (object): Batch object from the DataLoader. 121 | 122 | Returns: 123 | Tuple: Tuple of padded tensors of sets and tensors of set lengths. 124 | """ 125 | 126 | lengths = list(map(len, batch)) 127 | batch_size = len(batch) 128 | 129 | padded_seqs = torch.full( 130 | (batch_size, self.max_length, self.dim), self.pad_val, device=self.device 131 | ) 132 | 133 | for i, l in enumerate(lengths): 134 | padded_seqs[i, 0:l, :] = batch[i][0:l, :] 135 | 136 | return padded_seqs, torch.tensor(lengths) 137 | -------------------------------------------------------------------------------- /fdsa/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaccMann/fdsa/0e592e5281df69b3da2ea004ad3d96d9ca286f4d/fdsa/models/__init__.py -------------------------------------------------------------------------------- /fdsa/models/decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaccMann/fdsa/0e592e5281df69b3da2ea004ad3d96d9ca286f4d/fdsa/models/decoders/__init__.py -------------------------------------------------------------------------------- /fdsa/models/decoders/decoder_sets_ae.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | from fdsa.utils.hyperparameters import RNN_CELL_FACTORY 6 | 7 | 8 | class DecoderSetsAE(nn.Module): 9 | """Decoder from Sets AutoEncoder.""" 10 | 11 | def __init__(self, **params) -> None: 12 | """Constructor. 13 | 14 | Args: 15 | params (dict): A dictionary containing parameters required 16 | to build the decoder. Example: 17 | hidden_size_encoder(int): Hidden state dimension of the encoder. 18 | Defaults to 256. 19 | input_size(int): Input feature size. Defaults to 128. 20 | hidden_size_decoder(int): Hidden state dimension of the decoder. 21 | Defaults to 256. 22 | loss(str): Loss function to optimise. Defaults to 'CrossEntropy'. 23 | cell(str): Recurrent cell type to use as a decoder. Defaults to 24 | 'pLSTM'. 
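            Note that the decoder's input size is taken from hidden_size_encoder,
            since at each step the decoder consumes the encoder's read vector.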
25 | """ 26 | super(DecoderSetsAE, self).__init__() 27 | 28 | self.input_size = params.get('hidden_size_encoder', 256) 29 | self.output_dim = params.get('input_size', 128) 30 | self.hidden_size_decoder = params.get('hidden_size_decoder', 256) 31 | self.loss = params.get('loss', 'CrossEntropy') 32 | self.cell = params.get('cell', 'pLSTM') 33 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 34 | 35 | self.rnn = RNN_CELL_FACTORY[self.cell 36 | ](self.input_size, self.hidden_size_decoder) 37 | 38 | self.output_layer = nn.Linear(self.hidden_size_decoder, self.output_dim) 39 | 40 | if self.loss == 'CrossEntropy': 41 | self.prob_layer = nn.Linear(self.hidden_size_decoder, 2) 42 | 43 | elif self.loss == 'BCELogits': 44 | self.prob_layer = nn.Linear(self.hidden_size_decoder, 1) 45 | 46 | def forward(self, encoder_output: Tuple, length: int, p: int = 2) -> Tuple: 47 | """Executes batch processing of the decoder. 48 | 49 | Args: 50 | encoder_output (Tuple): Tuple of cell_state,hidden_state and 51 | read_vector from the last step of the encoder, all having shape 52 | [batch_size x hidden_size]. 53 | length (int): Maximum sequence length of the current batch. 54 | p (int, optional): the p-norm to use when calculating the 55 | cost matrix. Defaults to 2. 56 | 57 | Returns: 58 | Tuple: A tuple containing the mapped outputs and their member 59 | probabilities. 60 | """ 61 | 62 | if 'LSTM' in self.cell: 63 | cell_state, hidden_state, read_vector = encoder_output 64 | else: 65 | hidden_state, read_vector = encoder_output 66 | 67 | stacked_outputs = [] 68 | member_probabilities = [] 69 | 70 | read_vector0 = torch.zeros_like(read_vector) 71 | 72 | for i in range(length): 73 | 74 | if 'LSTM' in self.cell: 75 | new_hidden, new_cell = self.rnn( 76 | read_vector.to(self.device), 77 | (hidden_state.to(self.device), cell_state.to(self.device)) 78 | ) 79 | 80 | cell_state = new_cell 81 | 82 | else: 83 | new_hidden = self.rnn( 84 | read_vector.to(self.device), hidden_state.to(self.device) 85 | ) 86 | 87 | output = self.output_layer(new_hidden) 88 | member_probability = self.prob_layer(new_hidden) 89 | 90 | stacked_outputs.append(output) 91 | member_probabilities.append(member_probability) 92 | 93 | hidden_state = new_hidden 94 | read_vector = read_vector0 95 | 96 | stacked_outputs = torch.stack(stacked_outputs).permute(1, 0, 2) 97 | member_probabilities = torch.stack(member_probabilities).permute(1, 0, 2) 98 | 99 | return stacked_outputs, member_probabilities 100 | -------------------------------------------------------------------------------- /fdsa/models/decoders/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaccMann/fdsa/0e592e5281df69b3da2ea004ad3d96d9ca286f4d/fdsa/models/decoders/tests/__init__.py -------------------------------------------------------------------------------- /fdsa/models/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PaccMann/fdsa/0e592e5281df69b3da2ea004ad3d96d9ca286f4d/fdsa/models/encoders/__init__.py -------------------------------------------------------------------------------- /fdsa/models/encoders/deepsets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def running_global_pool_1d(inputs, pooling_type="MAX"): 6 | """Same global pool, but only for the elements up to the current element. 
7 |     Useful for outputs where the state of future elements is not known.
8 |     Takes no mask as all elements up to the current element are assumed to exist.
9 |     Currently only supports maximum. Equivalent to using a lower triangle bias.
10 |     Args:
11 |         inputs: A tensor of shape [batch_size, sequence_length, input_dims]
12 |             containing the sequences of input vectors.
13 |         pooling_type: Pooling type to use. Currently only supports 'MAX'.
14 |     Returns:
15 |         A tensor of shape [batch_size, sequence_length, input_dims] containing the
16 |             running 'totals'.
17 |     """
18 |     del pooling_type
19 |     output = torch.cummax(inputs, dim=1)[0]
20 |     return output
21 | 
22 | 
23 | def global_pool_1d(inputs, pooling_type="MAX", mask=None):
24 |     """Pool elements across the sequence dimension.
25 |     Useful to convert a list of vectors into a single vector so as
26 |     to get a representation of a set.
27 |     Args:
28 |         inputs: A tensor of shape [batch_size, sequence_length, input_dims]
29 |             containing the sequences of input vectors.
30 |         pooling_type: the pooling type to use, MAX or AVR.
31 |         mask: A tensor of shape [batch_size, sequence_length] containing a
32 |             mask for the inputs with 1's for existing elements, and 0's elsewhere.
33 |     Returns:
34 |         A tensor of shape [batch_size, input_dims] containing one pooled
35 |             vector per set.
36 |     """
37 | 
38 |     if mask is not None:
39 |         # Zero out the padded elements so they cannot contribute to the pool.
40 |         mask = mask.unsqueeze(2)
41 |         inputs = inputs * mask
42 | 
43 |     if pooling_type == "MAX":
44 |         output, _ = torch.max(inputs, 1, keepdim=False)
45 | 
46 |     elif pooling_type == "AVR":
47 |         if mask is not None:
48 |             output = torch.sum(inputs, 1, keepdim=False)
49 |             num_elems = torch.sum(mask, 1, keepdim=False)
50 |             # Guard against division by zero for empty sets.
51 |             output = torch.div(output, torch.clamp(num_elems, min=1.0))
52 |         else:
53 |             output = torch.mean(inputs, dim=1)
54 | 
55 |     return output
56 | 
57 | 
58 | def shape_list(x):
59 |     """Return list of dims, statically where possible."""
60 |     x = torch.as_tensor(x)
61 |     return list(x.size())
62 | 
63 | 
64 | def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs):
65 |     """Conditional conv_fn making kernel 1d or 2d depending on inputs shape.
66 |     Assumes channels-last inputs of shape [batch, height, width, channels].
67 |     """
68 |     static_shape = inputs.size()
69 |     if not static_shape or len(static_shape) != 4:
70 |         raise ValueError(
71 |             "Inputs to conv must have statically known rank 4. "
72 |             "Shape: " + str(static_shape)
73 |         )
74 | 
75 |     dilation_rate = kwargs.pop("dilation_rate", (1, 1))
76 | 
77 |     if kwargs.get("padding") == "LEFT":
78 |         assert kernel_size[0] % 2 == 1 and kernel_size[1] % 2 == 1
79 |         height_padding = 2 * (kernel_size[0] // 2) * dilation_rate[0]
80 |         if shape_list(inputs)[2] == 1:
81 |             width_padding = 0
82 |         else:
83 |             width_padding = 2 * (kernel_size[1] // 2) * dilation_rate[1]
84 |         # nn.functional.pad pads the trailing dimensions first:
85 |         # (channels, channels, width left, width right, height left, height right).
86 |         inputs = nn.functional.pad(
87 |             inputs, (0, 0, width_padding, 0, height_padding, 0)
88 |         )
89 |         kwargs["padding"] = 0
90 | 
91 |     def conv2d_kernel(kernel_size_arg):
92 |         """Run conv_fn in the channels-first layout PyTorch expects."""
93 |         result = conv_fn(
94 |             inputs.permute(0, 3, 1, 2),
95 |             static_shape[3],
96 |             filters,
97 |             kernel_size_arg,
98 |             dilation=dilation_rate,
99 |             **kwargs
100 |         )
101 |         return result.permute(0, 2, 3, 1)
102 | 
103 |     return conv2d_kernel(kernel_size)
104 | 
105 | 
106 | def conv(inputs, filters, kernel_size, dilation_rate=(1, 1), **kwargs):
107 | 
108 |     def _conv2d(x, *args, **kwargs):
109 |         return nn.Conv2d(*args, **kwargs)(x)
110 | 
111 |     return conv_internal(
112 |         _conv2d, inputs, filters, kernel_size, dilation_rate=dilation_rate, **kwargs
113 |     )
114 | 
115 | 
116 | def conv1d(inputs, filters, kernel_size, dilation_rate=1, **kwargs):
117 |     return torch.squeeze(
118 |         conv(
119 |             inputs.unsqueeze(2),
120 |             filters, (kernel_size, 1),
121 |             dilation_rate=(dilation_rate, 1),
122 |             **kwargs
123 |         ), 2
124 |     )
125 | 
126 | 
127 | def linear_set_layer(
128 |     layer_size, inputs, context=None, activation_fn=nn.ReLU(), dropout=0.0
129 | ):
130 |     """Basic layer type for doing funky things with sets.
131 |     Applies a linear transformation to each element in the input set.
132 |     If a context is supplied, it is concatenated with the inputs.
133 |     e.g. One can use global_pool_1d to get a representation of the set which
134 |     can then be used as the context for the next layer.
135 |     TODO: Add bias add (or control the biases used).
136 |     Args:
137 |         layer_size: Dimension to transform the input vectors to.
138 |         inputs: A tensor of shape [batch_size, sequence_length, input_dims]
139 |             containing the sequences of input vectors.
140 |         context: A tensor of shape [batch_size, context_dims] containing a global
141 |             statistic about the set.
142 |         activation_fn: The activation function to use.
143 |         dropout: Dropout probability.
144 |     Returns:
145 |         Tensor of shape [batch_size, sequence_length, output_dims] containing the
146 |             sequences of transformed vectors.
147 |     """
148 | 
149 |     batch_size, sequence_length, input_dims = inputs.size()
150 | 
151 |     # NOTE: this creates fresh (untrained) weights on every call; wrap the
152 |     # layer in a module if the transformation should be learned.
153 |     # nn.Conv1d expects [batch_size, input_dims, sequence_length].
154 |     linear_filter = nn.Conv1d(input_dims, layer_size, 1)
155 |     outputs = linear_filter(inputs.permute(0, 2, 1)).permute(0, 2, 1)
156 | 
157 |     if context is not None:
158 |         if context.dim() == 2:
159 |             context = context.unsqueeze(1)
160 |         cont_tfm = conv1d(context, layer_size, 1)
161 |         outputs += cont_tfm
162 |     if activation_fn is not None:
163 |         outputs = activation_fn(outputs)
164 |     if dropout != 0.0:
165 |         outputs = nn.functional.dropout(outputs, p=dropout)
166 |     return outputs
167 | 
168 | 
169 | def ravanbakhsh_set_layer(
170 |     layer_size,
171 |     inputs,
172 |     mask=None,
173 |     sequential=False,
174 |     activation_fn=nn.Tanh(),
175 |     dropout=0.0
176 | ):
177 |     """Layer from Deep Sets paper: https://arxiv.org/abs/1611.04500 .
178 |     More parameter-efficient version of a linear-set-layer with context.
179 |     Args:
180 |         layer_size: Dimension to transform the input vectors to.
181 |         inputs: A tensor of shape [batch_size, sequence_length, vector]
182 |             containing the sequences of input vectors.
183 |         mask: A tensor of shape [batch_size, sequence_length] containing a
184 |             mask for the inputs with 1's for existing elements, and 0's elsewhere.
185 |         sequential: If true, will use a running global pool so each element will
186 |             only depend on those before it. Set true if this layer is being used in
187 |             an output sequence.
188 |         activation_fn: The activation function to use.
189 |         dropout: dropout.
190 |     Returns:
191 |         Tensor of shape [batch_size, sequence_length, vector] containing the
192 |             sequences of transformed vectors.
193 |     """
194 |     del dropout
195 | 
196 |     if sequential:
197 |         return linear_set_layer(
198 |             layer_size,
199 |             inputs - running_global_pool_1d(inputs),
200 |             activation_fn=activation_fn
201 |         )
202 |     return linear_set_layer(
203 |         layer_size,
204 |         inputs - global_pool_1d(inputs, mask=mask).unsqueeze(1),
205 |         activation_fn=activation_fn
206 |     )
--------------------------------------------------------------------------------
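The pooling utilities above are the permutation-invariant building blocks of the DeepSets-style layers. As a quick orientation, here is a minimal usage sketch (not part of the repository; the import path is inferred from the project tree and the values are illustrative):

```python
import torch

from fdsa.models.encoders.deepsets import global_pool_1d

batch_size, set_length, dim = 4, 6, 8
x = torch.rand(batch_size, set_length, dim)

# Mark the first three elements of every set as real, the rest as padding.
mask = torch.zeros(batch_size, set_length)
mask[:, :3] = 1.0

pooled = global_pool_1d(x, pooling_type="AVR", mask=mask)
print(pooled.shape)  # torch.Size([4, 8]) -- one vector per set

# Pooling is permutation-invariant: shuffling the set elements changes nothing.
perm = torch.randperm(set_length)
assert torch.allclose(
    global_pool_1d(x, pooling_type="MAX"),
    global_pool_1d(x[:, perm, :], pooling_type="MAX")
)
```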
53 | """ 54 | 55 | batch_size, sequence_length, _ = data.size() 56 | 57 | read_vector = torch.zeros((batch_size, self.hidden_size_encoder) 58 | ).to(self.device) 59 | 60 | if states is None: 61 | if 'LSTM' in self.cell: 62 | 63 | hidden_state, cell_state = torch.zeros_like( 64 | read_vector 65 | ), torch.zeros_like(read_vector) 66 | 67 | else: 68 | 69 | hidden_state = torch.zeros_like(read_vector) 70 | 71 | memory_slots = self.memory_mapping(data) 72 | 73 | for _ in range(sequence_length): 74 | 75 | if 'LSTM' in self.cell: 76 | new_hidden_state, new_cell_state = self.rnn( 77 | read_vector, (hidden_state, cell_state) 78 | ) 79 | 80 | cell_state = new_cell_state 81 | 82 | else: 83 | new_hidden_state = self.rnn(read_vector, hidden_state) 84 | 85 | scalar_scores = torch.einsum('abc,ac->ab', (memory_slots, new_hidden_state)) 86 | 87 | attention_weights = torch.softmax(scalar_scores, dim=1) 88 | 89 | if _ == 0: 90 | assert scalar_scores.size() == torch.Size( 91 | [batch_size, sequence_length] 92 | ), 'Incorrect dimensions.' 93 | assert np.allclose( 94 | np.sum(attention_weights.detach().cpu().numpy()), batch_size 95 | ), 'Weights for each set do not sum to 1.' 96 | 97 | read_vector = torch.einsum('ab,abc->ac', 98 | (attention_weights, memory_slots)).to( 99 | self.device 100 | ) 101 | 102 | hidden_state = new_hidden_state 103 | 104 | if 'LSTM' in self.cell: 105 | return cell_state, hidden_state, read_vector 106 | else: 107 | return hidden_state, read_vector 108 | -------------------------------------------------------------------------------- /fdsa/models/encoders/tests/test_set_ae.py: -------------------------------------------------------------------------------- 1 | """Testing SetsEncoder""" 2 | import pytest 3 | import torch 4 | from fdsa.models.encoders.encoder_sets_ae import EncoderSetsAE 5 | 6 | 7 | @pytest.fixture 8 | def params(): 9 | batch_size = 5 10 | input_size = 10 11 | sequence_length = 8 12 | hidden_sizes_linear = 20 13 | hidden_sizes_encoder = 20 14 | cell = 'LSTM' 15 | return { 16 | 'cell': cell, 17 | 'batch_size': batch_size, 18 | 'input_size': input_size, 19 | 'sequence_length': sequence_length, 20 | 'hidden_size_linear': hidden_sizes_linear, 21 | 'hidden_size_encoder': hidden_sizes_encoder 22 | } 23 | 24 | 25 | def test_memory_mapping(params): 26 | """Test linear mapping to memory locations.""" 27 | 28 | input_set = torch.rand( 29 | ( 30 | params['batch_size'], params['sequence_length'], 31 | params['input_size'] 32 | ) 33 | ) 34 | set_ae = EncoderSetsAE(**params) 35 | 36 | with torch.no_grad(): 37 | memory_slots = set_ae.memory_mapping(input_set) 38 | assert memory_slots.size() == torch.Size( 39 | [ 40 | params['batch_size'], params['sequence_length'], 41 | params['hidden_size_linear'] 42 | ] 43 | ) 44 | 45 | 46 | def test_set_ae(params): 47 | """Test dimension correctness of the encoder outputs.""" 48 | 49 | input_set = torch.rand( 50 | ( 51 | params['batch_size'], params['sequence_length'], 52 | params['input_size'] 53 | ) 54 | ) 55 | set_ae = EncoderSetsAE(**params) 56 | 57 | cell_state, hidden_state, read_vector = set_ae(input_set) 58 | 59 | assert cell_state.size() == torch.Size( 60 | [params['batch_size'], params['hidden_size_encoder']] 61 | ) 62 | assert hidden_state.size() == torch.Size( 63 | [params['batch_size'], params['hidden_size_encoder']] 64 | ) 65 | 66 | assert read_vector.size() == torch.Size( 67 | [params['batch_size'], params['hidden_size_encoder']] 68 | ) 69 | -------------------------------------------------------------------------------- 
/fdsa/models/set_matching/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/fdsa/models/set_matching/cnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from fdsa.utils.hyperparameters import ACTIVATION_FN_FACTORY, POOLING_FN_FACTORY
4 | 
5 | 
6 | class CNNSetMatching(nn.Module):
7 |     """Generalisable CNN module to allow for flexibility in architecture."""
8 | 
9 |     def __init__(self, **params) -> None:
10 |         """Constructor.
11 | 
12 |         Args:
13 |             params (dict) containing the following keys:
14 |                 img_height (int): Height of the input images.
15 |                 img_width (int): Width of the input images.
16 |                 input_size (int): The number of input channels.
17 |                 output_channels (int): The desired number of output channels.
18 |                 conv_layers (int): Number of convolution layers to apply.
19 |                 kernel_size (List[(tuple or int)]): Size of the convolving kernel.
20 |                 stride (List[(tuple or int)]): Stride of the convolution.
21 |                 padding (List[(tuple or int)]): Zero-padding added to both sides of the
22 |                     input.
23 |                 padding_mode (str): 'zeros', 'reflect', 'replicate' or 'circular'.
24 |                 dilation (List[(tuple or int)]): Spacing between kernel elements.
25 |                 conv_activation (str): Activation to apply after convolution.
26 |                     See utils/hyperparameter.py for options.
27 | 
28 |                 pooling (str): Type of pooling to apply.
29 |                     See utils/hyperparameter.py for options.
30 |                 pooling_kernel_size (List[(tuple or int)]): The size of the window to
31 |                     pool over.
32 |                 pooling_kernel_stride (List[(tuple or int)]): The stride of the window.
33 |                 pooling_kernel_padding (List[(tuple or int)]): Implicit zero padding to
34 |                     be added on both sides.
35 |                 pooling_kernel_dilation (List[(tuple or int)]): Controls the stride of
36 |                     elements in the window.
37 | 
38 |                 fc_layers (int): Number of fully connected layers to add.
39 |                 fc_units (List[(int)]): List of hidden units for each
40 |                     fully connected layer.
41 |                 fc_activation (str): Activation to apply after linear transform.
42 |                     See utils/hyperparameter.py for options.
41 | """ 42 | super(CNNSetMatching, self).__init__() 43 | 44 | self.img_height = params['img_height'] 45 | self.img_width = params['img_width'] 46 | self.input_channel = params['input_size'] 47 | self.output_channels = params['output_channels'] 48 | self.conv_layers = params['conv_layers'] 49 | self.kernel_size = params['kernel_size'] 50 | self.stride = params['stride'] 51 | self.padding = params['padding'] 52 | self.padding_mode = params['padding_mode'] 53 | self.dilation = params['dilation'] 54 | self.conv_activation = params['conv_activation'] 55 | 56 | self.pooling = params['pooling'] 57 | self.pooling_kernel_size = params['pooling_kernel_size'] 58 | self.pooling_kernel_stride = params['pooling_kernel_stride'] 59 | self.pooling_kernel_padding = params['pooling_kernel_padding'] 60 | self.pooling_kernel_dilation = params['pooling_kernel_dilation'] 61 | 62 | self.fc_layers = params['fc_layers'] 63 | self.fc_units = params['fc_units'] 64 | self.fc_activation = params['fc_activation'] 65 | 66 | modules_conv = [] 67 | out_channels = [self.input_channel] + self.output_channels 68 | 69 | w = self.img_width 70 | h = self.img_height 71 | 72 | for layer in range(self.conv_layers): 73 | conv = nn.Conv2d( 74 | out_channels[layer], 75 | out_channels[layer + 1], 76 | self.kernel_size[layer], 77 | self.stride[layer], 78 | self.padding[layer], 79 | self.dilation[layer], 80 | padding_mode=self.padding_mode 81 | ) 82 | 83 | modules_conv.append(conv) 84 | w = self.compute_output_img_size( 85 | w, self.kernel_size[layer], self.padding[layer], self.stride[layer] 86 | ) 87 | h = self.compute_output_img_size( 88 | h, self.kernel_size[layer], self.padding[layer], self.stride[layer] 89 | ) 90 | activation = ACTIVATION_FN_FACTORY[self.conv_activation] 91 | 92 | modules_conv.append(activation) 93 | 94 | pooling = POOLING_FN_FACTORY[self.pooling]( 95 | self.pooling_kernel_size[layer], self.pooling_kernel_stride[layer], 96 | self.pooling_kernel_padding, self.pooling_kernel_dilation 97 | ) 98 | 99 | modules_conv.append(pooling) 100 | 101 | w = self.compute_output_img_size( 102 | w, self.pooling_kernel_size[layer], self.pooling_kernel_padding[layer], 103 | self.pooling_kernel_stride[layer] 104 | ) 105 | h = self.compute_output_img_size( 106 | h, self.pooling_kernel_size[layer], self.pooling_kernel_padding[layer], 107 | self.pooling_kernel_stride[layer] 108 | ) 109 | 110 | self.model_conv = nn.Sequential(*modules_conv) 111 | 112 | self.output_img_size = int(w * h * self.output_channels[-1]) 113 | 114 | linear_units = [self.output_img_size] + self.fc_units 115 | modules_linear = [] 116 | 117 | for layer in range(self.fc_layers): 118 | fc = nn.Linear(linear_units[layer], linear_units[layer + 1]) 119 | modules_linear.append(fc) 120 | if self.fc_activation is not None: 121 | modules_linear.append(ACTIVATION_FN_FACTORY[self.fc_activation]) 122 | 123 | self.model_fc = nn.Sequential(*modules_linear) 124 | 125 | def compute_output_img_size( 126 | self, input_size: int, filter_size: int, padding: int, stride: int 127 | ) -> int: 128 | """Computes the size of the output from a CNN in one dimension. 129 | 130 | Args: 131 | input_size (int): The number of input channels/ dimensions. 132 | filter_size (int): Size of the convolving kernel/pooling window. 133 | padding (int): Zero-paddings on both sides for padding number of 134 | points. 135 | stride (int): Stride of the convolution. 136 | 137 | Returns: 138 | int: Output image size along one dimension. 
139 | """ 140 | return 1 + (input_size - filter_size + 2 * padding) / stride 141 | 142 | def forward(self, x: torch.Tensor) -> torch.Tensor: 143 | """Applies convolutions and specified transformations on input tensor. 144 | 145 | Args: 146 | x (torch.Tensor): Input tensor of shape 147 | [batch_size, in_channels, in_height, in_width]. 148 | 149 | Returns: 150 | torch.Tensor: Transformed tensor of shape 151 | [batch_size, out_channels, out_height, out_width] if there is no 152 | fc layer. Otherwise, [batch_size, linear_units[-1]] 153 | """ 154 | x = self.model_conv(x) 155 | x = x.view(-1, self.output_img_size) 156 | x = self.model_fc(x) 157 | 158 | return x 159 | -------------------------------------------------------------------------------- /fdsa/models/set_matching/dnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from fdsa.utils.hyperparameters import ACTIVATION_FN_FACTORY 4 | 5 | 6 | class DNNSetMatching(nn.Module): 7 | """Generalisable DNN module to allow for flexibility in architecture.""" 8 | 9 | def __init__(self, params: dict) -> None: 10 | """Constructor. 11 | 12 | Args: 13 | params (dict): DNN parameter dictionary with the following keys: 14 | input_size (int): Input tensor dimensions. 15 | fc_layers (int): Number of fully connected layers to add. 16 | fc_units (List[(int)]): List of hidden units for each layer. 17 | fc_activation (str): Activation function to apply after each 18 | fully connected layer. See utils/hyperparameter.py 19 | for options. 20 | 21 | """ 22 | super(DNNSetMatching, self).__init__() 23 | 24 | self.input_size = params['input_size'] 25 | self.layers = params['fc_layers'] 26 | self.hidden_size = params['fc_units'] 27 | self.activation = params['fc_activation'] 28 | 29 | modules = [] 30 | hidden_units = [self.input_size] + self.hidden_size 31 | for layer in range(self.layers): 32 | modules.append(nn.Linear(hidden_units[layer], hidden_units[layer + 1])) 33 | if self.activation[layer] != 'None': 34 | modules.append(ACTIVATION_FN_FACTORY[self.activation[layer]]) 35 | self.model = nn.Sequential(*modules) 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | """Passes input through a feed forward neural network. 39 | 40 | Args: 41 | x (torch.Tensor): Input tensor of shape [batch_size,*,input_size] 42 | 43 | Returns: 44 | torch.Tensor: Output tensor of shape [batch_size,*, hidden_sizes[-1]]. 45 | """ 46 | 47 | return self.model(x) 48 | -------------------------------------------------------------------------------- /fdsa/models/set_matching/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from fdsa.utils.hyperparameters import RNN_CELL_FACTORY 5 | from fdsa.utils.hyperparameters import RNN_FACTORY 6 | from fdsa.utils.hyperparameters import ACTIVATION_FN_FACTORY 7 | from brc_pytorch.layers import MultiLayerBase 8 | from fdsa.utils.layers.select_item import SelectItem 9 | 10 | 11 | class RNNSetMatching(nn.Module): 12 | """Generalisable RNN module to allow for flexibility in architecture.""" 13 | 14 | def __init__( 15 | self, 16 | params: dict, 17 | device: torch.device = torch. 18 | device('cuda' if torch.cuda.is_available() else 'cpu') 19 | ) -> None: 20 | """Constructor. 21 | 22 | Args: 23 | params (dict) with the following keys: 24 | input_size (int): The dimension/size of the input element. 25 | Defaults to 128. 26 | max_length (int): The maximum length of the set. 
/fdsa/models/set_matching/dnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from fdsa.utils.hyperparameters import ACTIVATION_FN_FACTORY
4 | 
5 | 
6 | class DNNSetMatching(nn.Module):
7 |     """Generalisable DNN module to allow for flexibility in architecture."""
8 | 
9 |     def __init__(self, params: dict) -> None:
10 |         """Constructor.
11 | 
12 |         Args:
13 |             params (dict): DNN parameter dictionary with the following keys:
14 |                 input_size (int): Input tensor dimensions.
15 |                 fc_layers (int): Number of fully connected layers to add.
16 |                 fc_units (List[(int)]): List of hidden units for each layer.
17 |                 fc_activation (List[(str)]): Activation function to apply after
18 |                     each fully connected layer. See utils/hyperparameter.py
19 |                     for options.
20 | 
21 |         """
22 |         super(DNNSetMatching, self).__init__()
23 | 
24 |         self.input_size = params['input_size']
25 |         self.layers = params['fc_layers']
26 |         self.hidden_size = params['fc_units']
27 |         self.activation = params['fc_activation']
28 | 
29 |         modules = []
30 |         hidden_units = [self.input_size] + self.hidden_size
31 |         for layer in range(self.layers):
32 |             modules.append(nn.Linear(hidden_units[layer], hidden_units[layer + 1]))
33 |             if self.activation[layer] != 'None':
34 |                 modules.append(ACTIVATION_FN_FACTORY[self.activation[layer]])
35 |         self.model = nn.Sequential(*modules)
36 | 
37 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
38 |         """Passes input through a feed forward neural network.
39 | 
40 |         Args:
41 |             x (torch.Tensor): Input tensor of shape [batch_size, *, input_size]
42 | 
43 |         Returns:
44 |             torch.Tensor: Output tensor of shape [batch_size, *, fc_units[-1]].
45 |         """
46 | 
47 |         return self.model(x)
--------------------------------------------------------------------------------
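DNNSetMatching is simply a configurable MLP applied to the trailing dimension; a short sketch with illustrative values:

```python
import torch

from fdsa.models.set_matching.dnn import DNNSetMatching

dnn = DNNSetMatching({
    'input_size': 128,
    'fc_layers': 2,
    'fc_units': [64, 5],
    'fc_activation': ['lrelu', 'None'],
})
print(dnn(torch.rand(4, 10, 128)).shape)  # torch.Size([4, 10, 5])
```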
/fdsa/models/set_matching/rnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from brc_pytorch.layers import MultiLayerBase
5 | from fdsa.utils.hyperparameters import (
6 |     ACTIVATION_FN_FACTORY, RNN_CELL_FACTORY, RNN_FACTORY
7 | )
8 | from fdsa.utils.layers.select_item import SelectItem
9 | 
10 | 
11 | class RNNSetMatching(nn.Module):
12 |     """Generalisable RNN module to allow for flexibility in architecture."""
13 | 
14 |     def __init__(
15 |         self,
16 |         params: dict,
17 |         device: torch.device = torch.
18 |         device('cuda' if torch.cuda.is_available() else 'cpu')
19 |     ) -> None:
20 |         """Constructor.
21 | 
22 |         Args:
23 |             params (dict) with the following keys:
24 |                 input_size (int): The dimension/size of the input element.
25 |                     Defaults to 128.
26 |                 max_length (int): The maximum length of the set.
27 |                     Defaults to 5.
28 |                 cell (str): RNN cell to use. One of 'LSTM','GRU','BRC' or
29 |                     'nBRC'. Defaults to 'GRU'.
30 |                 layers (int): The number of RNN layers required. Defaults to 1.
31 |                 hidden_size (int): The hidden/cell state dimensionality.
32 |                     Defaults to 512.
33 |                 bidirectional (str): 'True' or 'False' for a bidirectional RNN.
34 |                     Defaults to False.
35 |                 batch_first (str): 'True' or 'False' to indicate if batch_size
36 |                     comes first or set_length. Defaults to False.
37 |                 return_sequences (str): 'True' returns hidden states at all time
38 |                     steps from the last layer. 'False' returns the hidden state
39 |                     of the last time step from the last layer. Defaults to True.
40 |                 fc_layers (int): Number of fully connected layers
41 |                     required after the RNN. Defaults to 2.
42 |                 fc_units (List[(int)]): List of hidden units for each fully
43 |                     connected layer. Defaults to [128,5].
44 |                 fc_activation (List[(str)]): Activation function to apply after
45 |                     each fully connected layer. See utils/hyperparameter.py
46 |                     for options. Defaults to ['lrelu', 'None'].
47 |             device (torch.device): Device on which the model is executed.
48 |                 Defaults to CPU.
49 |         """
50 |         super(RNNSetMatching, self).__init__()
51 | 
52 |         self.input_size = params.get('input_size', 128)
53 |         self.seq_len = params.get('max_length', 5)
54 |         self.cell = params.get('cell', 'GRU')
55 |         self.rnn_layers = params.get('layers', 1)
56 |         self.hidden_size = params.get('hidden_size', 512)
57 |         self.bidirectional = eval(params.get('bidirectional', 'False'))
58 |         self.batch_first = eval(params.get('batch_first', 'False'))
59 |         self.return_sequences = eval(params.get('return_sequences', 'True'))
60 |         self.fc_layers = params.get('fc_layers', 2)
61 |         self.fc_units = params.get('fc_units', [128, 5])
62 |         self.fc_activation = params.get('fc_activation', ['lrelu', 'None'])
63 | 
64 |         self.device = device
65 | 
66 |         modules_rnn = []
67 |         modules_fc = []
68 | 
69 |         num_directions = 2 if self.bidirectional else 1
70 | 
71 |         if self.cell == 'BRC' or self.cell == 'nBRC':
72 | 
73 |             inner_input_dimensions = num_directions * self.hidden_size
74 | 
75 |             recurrent_layers = [
76 |                 RNN_CELL_FACTORY[self.cell](self.input_size, self.hidden_size)
77 |             ]
78 | 
79 |             for _ in range(self.rnn_layers - 1):
80 |                 recurrent_layers.append(
81 |                     RNN_CELL_FACTORY[self.cell]
82 |                     (inner_input_dimensions, self.hidden_size)
83 |                 )
84 | 
85 |             rnn = MultiLayerBase(
86 |                 mode=self.cell,
87 |                 cells=recurrent_layers,
88 |                 hidden_size=self.hidden_size,
89 |                 batch_first=self.batch_first,
90 |                 bidirectional=self.bidirectional,
91 |                 return_sequences=self.return_sequences,
92 |                 device=self.device
93 |             )
94 | 
95 |             modules_rnn.append(rnn)
96 | 
97 |             if self.return_sequences:
98 |                 modules_rnn.append(SelectItem(0, -self.seq_len, self.batch_first))
99 | 
100 |             if self.fc_layers is not None:
101 |                 if self.bidirectional:
102 |                     hidden_units = [2 * self.hidden_size] + self.fc_units
103 |                 else:
104 |                     hidden_units = [self.hidden_size] + self.fc_units
105 | 
106 |                 for layer in range(self.fc_layers):
107 |                     modules_fc.append(
108 |                         nn.Linear(hidden_units[layer], hidden_units[layer + 1])
109 |                     )
110 |                     if self.fc_activation[layer] != 'None':
111 |                         modules_fc.append(
112 |                             ACTIVATION_FN_FACTORY[self.fc_activation[layer]]
113 |                         )
114 | 
115 |         else:
116 |             rnn = RNN_FACTORY[self.cell](
117 |                 self.input_size,
118 |                 self.hidden_size,
119 |                 num_layers=self.rnn_layers,
120 |                 batch_first=self.batch_first,
121 |                 bidirectional=self.bidirectional,
122 |             )
123 | 
124 |             modules_rnn.append(rnn)
125 | 
126 |             if self.return_sequences:
127 |                 modules_rnn.append(SelectItem(0, -self.seq_len, self.batch_first))
128 |             else:
129 |                 # Subset when only the last time step of the last layer is needed.
130 |                 modules_rnn.append(SelectItem(0, -1, self.batch_first))
131 | 
132 |             if self.fc_layers is not None:
133 |                 if self.bidirectional:
134 |                     hidden_units = [2 * self.hidden_size] + self.fc_units
135 |                 else:
136 |                     hidden_units = [self.hidden_size] + self.fc_units
137 | 
138 |                 for layer in range(self.fc_layers):
139 |                     modules_fc.append(
140 |                         nn.Linear(hidden_units[layer], hidden_units[layer + 1])
141 |                     )
142 |                     if self.fc_activation[layer] != 'None':
143 |                         modules_fc.append(
144 |                             ACTIVATION_FN_FACTORY[self.fc_activation[layer]]
145 |                         )
146 | 
147 |         self.rnn = nn.Sequential(*modules_rnn)
148 |         self.fc = nn.Sequential(*modules_fc)
149 | 
150 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
151 |         """Passes input through the specified network.
152 | 
153 |         Args:
154 |             x (torch.Tensor): Input tensor of shape [batch_size, seq_len, dim].
155 | 
156 |         Returns:
157 |             torch.Tensor: Output tensor of shape [batch_size, *, fc_units[-1]].
158 |         """
159 |         x = self.rnn(x)
160 |         x = self.fc(x)
161 | 
162 |         return x
--------------------------------------------------------------------------------
/fdsa/models/set_matching/selectrnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from brc_pytorch.layers import MultiLayerBase
4 | from fdsa.utils.hyperparameters import RNN_CELL_FACTORY, RNN_FACTORY
5 | 
6 | 
7 | class SelectRNN(nn.Module):
8 |     """Allows to switch between torch implemented GRU/LSTM and in-house
9 |     implementation of multilayer BRC/nBRC"""
10 | 
11 |     def __init__(
12 |         self,
13 |         params: dict,
14 |         device: torch.device = torch.
15 |         device('cuda' if torch.cuda.is_available() else 'cpu')
16 |     ) -> None:
17 |         """Constructor.
18 | 
19 |         Args:
20 |             params (dict): Dictionary of parameters necessary for the RNN.
21 |                 Example:
22 |                 input_size (int): The dimension/size of the input element.
23 |                     Defaults to 128.
24 |                 max_length (int): The maximum length of the set. Defaults to 5.
25 |                 cell (str): The RNN cell to use as a decoder. Defaults to GRU.
26 |                 layers (int): The number of RNN layers required. Defaults to 1.
27 |                 hidden_size (int): The hidden/cell state dimensionality.
28 |                     Defaults to 512.
29 |                 bidirectional (str): 'True' or 'False' for a bidirectional RNN.
30 |                     Defaults to False.
31 |                 batch_first (str): 'True' or 'False' to indicate if batch_size
32 |                     comes first or set_length. Defaults to False.
33 |                 return_sequences (str): 'True' returns hidden states at all time
34 |                     steps from the last layer. 'False' returns the hidden state
35 |                     of the last time step from the last layer. Defaults to True.
36 |             device (torch.device): Device on which operations are run.
37 |                 Defaults to CPU.
38 | """ 39 | 40 | super(SelectRNN, self).__init__() 41 | 42 | self.input_size = params.get('input_size', 128) 43 | self.seq_len = params.get('max_length', 5) 44 | self.cell = params.get('cell', 'GRU') 45 | self.rnn_layers = params.get('layers', 1) 46 | self.hidden_size = params.get('hidden_size', 512) 47 | self.bidirectional = eval(params.get('bidirectional', 'False')) 48 | self.batch_first = eval(params.get('batch_first', 'False')) 49 | self.return_sequences = eval(params.get('return_sequences', 'True')) 50 | self.device = device 51 | 52 | num_directions = 2 if self.bidirectional else 1 53 | 54 | if self.cell == 'BRC' or self.cell == 'nBRC': 55 | 56 | recurrent_layers = [ 57 | RNN_CELL_FACTORY[self.cell](self.input_size, self.hidden_size) 58 | ] 59 | 60 | inner_input_dimensions = num_directions * self.hidden_size 61 | 62 | for _ in range(self.rnn_layers - 1): 63 | recurrent_layers.append( 64 | RNN_CELL_FACTORY[self.cell] 65 | (inner_input_dimensions, self.hidden_size) 66 | ) 67 | 68 | self.rnn = MultiLayerBase( 69 | mode=self.cell, 70 | cells=recurrent_layers, 71 | hidden_size=self.hidden_size, 72 | batch_first=self.batch_first, 73 | bidirectional=self.bidirectional, 74 | return_sequences=self.return_sequences, 75 | device=self.device 76 | ) 77 | 78 | else: 79 | 80 | self.rnn = RNN_FACTORY[self.cell]( 81 | self.input_size, 82 | self.hidden_size, 83 | num_layers=self.rnn_layers, 84 | batch_first=self.batch_first, 85 | bidirectional=self.bidirectional 86 | ) 87 | 88 | def forward(self): 89 | """Returns the chosen model.""" 90 | return self.rnn 91 | -------------------------------------------------------------------------------- /fdsa/models/set_matching/seq2seq.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | from fdsa.models.set_matching.seq2seq_decoder import Seq2SeqDecoder 6 | from fdsa.models.set_matching.seq2seq_encoder import Seq2SeqEncoder 7 | 8 | 9 | class Seq2Seq(nn.Module): 10 | """Complete Sequence to Sequence Model""" 11 | 12 | def __init__( 13 | self, 14 | params: dict, 15 | device: torch.device = torch. 16 | device('cuda' if torch.cuda.is_available() else 'cpu') 17 | ): 18 | """Constructor. 19 | 20 | Args: 21 | params (dict): Dictionary containing all the parameters necessary 22 | for seq2seq encoder and decoder. See Seq2SeqEncoder and 23 | Seq2SeqDecoder for more details. 24 | device (torch.device): Device on which operations are run. 25 | Defaults to CPU. 26 | """ 27 | 28 | super(Seq2Seq, self).__init__() 29 | 30 | self.device = device 31 | self.max_length = params.get('max_length', 5) 32 | self.cell = params.get('cell', 'GRU') 33 | self.batch_first = eval(params.get('batch_first', 'False')) 34 | 35 | self.encoder = Seq2SeqEncoder(params, self.device) 36 | self.decoder = Seq2SeqDecoder(params, self.device) 37 | 38 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> Tuple: 39 | """Computes output and attention weights of seq2seq model. 40 | 41 | Args: 42 | x (torch.Tensor): Input tensor to be encoded. The input set for 43 | matching is the concatenation of the two sets to be matched with 44 | a connecting token. Of the two sets, one is deemed to be the 45 | reference set while the other set is reordered accordingly. 46 | Shape : [2*set_length+1,batch_size,dim] 47 | y (torch.Tensor): The reference set for aligning the other set. 
48 | Shape : [set_length,batch_size,dim] 49 | 50 | Returns: 51 | Tuple: Tuple of the outputs associated with each element of the 52 | reference set and attention weights used to compute outputs. 53 | NOTE: This model always assumes batch_first = False. 54 | TODO: Make this class, and subsequently the encoder and decoder classes 55 | flexible to handle batch_first = True as well. 56 | """ 57 | 58 | length, batch_size, dim = x.size() 59 | 60 | outputs = torch.zeros(self.max_length, batch_size, 61 | self.max_length).to(self.device) 62 | 63 | attention_weights = torch.zeros( 64 | self.max_length, dim, batch_size, self.max_length 65 | ).to(self.device) 66 | 67 | encoder_outputs, encoder_hn = self.encoder(x) 68 | 69 | decoder_hidden = encoder_hn 70 | 71 | for i in range(y.size(0)): 72 | 73 | output, hidden_state, attn_wts = self.decoder( 74 | y[i], decoder_hidden, encoder_outputs 75 | ) 76 | 77 | outputs[i] = output 78 | attention_weights[i] = attn_wts 79 | decoder_hidden = hidden_state 80 | 81 | return outputs, attention_weights 82 | -------------------------------------------------------------------------------- /fdsa/models/set_matching/seq2seq_decoder.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from fdsa.models.set_matching.dnn import DNNSetMatching 8 | from fdsa.models.set_matching.selectrnn import SelectRNN 9 | 10 | 11 | class Seq2SeqDecoder(nn.Module): 12 | """Decoder component of the Sequence to Sequence model.""" 13 | 14 | def __init__(self, params: dict, device: torch.device) -> None: 15 | """Constructor. 16 | 17 | Args: 18 | params (dict): Dictionary containing parameters necessary for the 19 | decoder. Example: 20 | cell (str): The RNN cell to use as a decoder. Defaults to 'GRU'. 21 | input_size (int): The dimension/size of the input element. 22 | Defaults to 128. 23 | max_length (int): The maximum length of the set. 24 | Defaults to 5. 25 | hidden_size (int): The hidden/cell state dimensionality. 26 | Defaults to 512. 27 | fc_layers (int): The number of fully connected layers. 28 | Defaults to 2. 29 | fc_units (List(int)): Number of hidden units for each FC layer. 30 | Defaults to [128,5]. 31 | fc_activation (str): Activation function to apply on each FC 32 | layer. ['lrelu', 'None']. 33 | device (torch.device): Device on which operations are run. 
34 | 
35 |         """
36 | 
37 |         super(Seq2SeqDecoder, self).__init__()
38 | 
39 |         self.cell = params.get('cell', 'GRU')
40 |         self.input_size = params.get('input_size', 128)
41 |         self.max_length = params.get('max_length', 5)
42 | 
43 |         self.params = deepcopy(params)
44 | 
45 |         self.hidden_size = params.get('hidden_size', 512)
46 | 
47 |         self.params_fc = dict(
48 |             {
49 |                 'input_size': self.hidden_size,
50 |                 'fc_layers': params.get('fc_layers', 2),
51 |                 'fc_units': params.get('fc_units', [128, 5]),
52 |                 'fc_activation': params.get('fc_activation', ['lrelu', 'None'])
53 |             }
54 |         )
55 | 
56 |         self.params.update({'input_size': self.hidden_size, 'bidirectional': 'False'})
57 | 
58 |         self.model = SelectRNN(self.params, device)()
59 | 
60 |         self.attention = nn.Sequential(
61 |             nn.Linear(self.hidden_size + self.input_size, self.max_length),
62 |             nn.Softmax(1)
63 |         )
64 |         self.add_attention = nn.Linear(
65 |             2 * self.hidden_size + self.input_size, self.hidden_size
66 |         )
67 | 
68 |         self.fc = DNNSetMatching(self.params_fc)
69 | 
70 |     def forward(
71 |         self, y: torch.Tensor, hidden_state: torch.Tensor, encoder_outputs: torch.Tensor
72 |     ) -> Tuple:
73 |         """Forward pass of decoder.
74 | 
75 |         Args:
76 |             y (torch.Tensor): Element of the reference set.
77 |                 Shape: [batch_size, dim]
78 |             hidden_state (torch.Tensor): The last hidden state from the encoder.
79 |                 Shape: [num_layers * num_directions, batch_size, hidden_size]
80 |             encoder_outputs (torch.Tensor): All hidden states from the encoder.
81 |                 Shape: [set_length, batch_size, num_directions * hidden_size]
82 | 
83 |         Returns:
84 |             Tuple: Tuple containing the output, new hidden state and attention
85 |                 weights used to compute the output for that element.
86 |         """
87 | 
88 |         if 'LSTM' in self.cell:
89 |             dec_hidden = (
90 |                 hidden_state[0][0].unsqueeze(0), hidden_state[1][0].unsqueeze(0)
91 |             )
92 |             hidden_state_n = hidden_state[0][-1].squeeze(0)
93 |         else:
94 |             dec_hidden = hidden_state[0].unsqueeze(0)
95 |             hidden_state_n = hidden_state[-1].squeeze(0)
96 | 
97 |         attention_weights = self.attention(torch.cat((y, hidden_state_n),
98 |                                                      1)).unsqueeze(2).permute(0, 2, 1)
99 | 
100 |         attn_applied = torch.matmul(
101 |             attention_weights,
102 |             encoder_outputs[-self.max_length:, :, :].permute(1, 0, 2)
103 |         )
104 | 
105 |         # Squeeze only the singleton step dimension, so that a batch size of
106 |         # one is not squeezed away as well.
107 |         output = torch.cat((y, attn_applied.squeeze(1)), 1)
108 | 
109 |         output = self.add_attention(output).unsqueeze_(0)
110 | 
111 |         output = F.relu(output)
112 | 
113 |         output, hidden_state = self.model(output, dec_hidden)
114 | 
115 |         output = self.fc(output)
116 | 
117 |         return output, hidden_state, attention_weights.permute(1, 0, 2)
--------------------------------------------------------------------------------
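The decoder consumes one reference element at a time and attends over the encoder states. A shape-level sketch (all values illustrative; note the encoder is bidirectional, hence the 2 * hidden_size feature dimension of encoder_outputs):

```python
import torch

from fdsa.models.set_matching.seq2seq_decoder import Seq2SeqDecoder

params = {
    'cell': 'GRU', 'input_size': 16, 'max_length': 5, 'hidden_size': 32,
    'layers': 1, 'batch_first': 'False', 'return_sequences': 'True',
    'fc_layers': 2, 'fc_units': [16, 5], 'fc_activation': ['lrelu', 'None']
}
decoder = Seq2SeqDecoder(params, torch.device('cpu'))

batch_size = 3
y_t = torch.rand(batch_size, 16)                       # one reference element
hidden = torch.zeros(2, batch_size, 32)                # encoder h_n: [layers * dirs, B, H]
encoder_outputs = torch.rand(11, batch_size, 2 * 32)   # 2 * max_length + 1 encoder steps
output, new_hidden, attn = decoder(y_t, hidden, encoder_outputs)
print(output.shape, attn.shape)  # torch.Size([1, 3, 5]) torch.Size([1, 3, 5])
```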
/fdsa/models/set_matching/seq2seq_encoder.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from typing import Tuple
3 | 
4 | import torch
5 | import torch.nn as nn
6 | from fdsa.models.set_matching.selectrnn import SelectRNN
7 | 
8 | 
9 | class Seq2SeqEncoder(nn.Module):
10 |     """Encoder component of Sequence to Sequence model."""
11 | 
12 |     def __init__(self, params: dict, device: torch.device) -> None:
13 |         """Constructor.
14 | 
15 |         Args:
16 |             params (dict): Dictionary containing all the parameters necessary
17 |                 for the encoder. Example:
18 |                 cell (str): The RNN cell to use as an encoder.
19 |                 input_size (int): The dimension/size of the input element.
20 |                 max_length (int): The maximum length of the set.
21 |                 hidden_size (int): The hidden/cell state dimensionality.
22 |                 bidirectional (str): True if RNN should be bidirectional.
23 |             device (torch.device): Device on which operations are run.
24 |         """
25 | 
26 |         super(Seq2SeqEncoder, self).__init__()
27 |         # Store the config under `self.params`; an attribute named
28 |         # `parameters` would shadow nn.Module.parameters().
29 |         self.params = deepcopy(params)
30 |         self.params.update({'max_length': 2 * params['max_length'] + 1})
31 |         self.params.update({'bidirectional': 'True'})
32 |         self.model = SelectRNN(self.params, device)()
33 | 
34 |     def forward(self, x: torch.Tensor) -> Tuple:
35 |         """Forward pass of encoder.
36 | 
37 |         Args:
38 |             x (torch.Tensor): Input tensor to be encoded.
39 |                 Shape: [input_length, batch_size, dim]. The tensor passed as
40 |                 input is the concatenation of the two sets connected by a unit
41 |                 length tensor of the same batch size and dim.
42 | 
43 |         Returns:
44 |             Tuple: Tuple of the hidden state from all steps and hidden state
45 |                 from the last step only.
46 |         """
47 | 
48 |         return self.model(x)
--------------------------------------------------------------------------------
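End to end, Seq2Seq concatenates the two sets (plus a connector token) as encoder input and decodes against the reference set; a minimal run with illustrative sizes:

```python
import torch

from fdsa.models.set_matching.seq2seq import Seq2Seq

params = {
    'cell': 'GRU', 'input_size': 16, 'max_length': 5, 'hidden_size': 32,
    'layers': 1, 'batch_first': 'False', 'return_sequences': 'True',
    'fc_layers': 2, 'fc_units': [16, 5], 'fc_activation': ['lrelu', 'None']
}
model = Seq2Seq(params, torch.device('cpu'))

batch_size = 3
x = torch.rand(2 * 5 + 1, batch_size, 16)  # both sets plus the connector token
y = torch.rand(5, batch_size, 16)          # the reference set
outputs, attention = model(x, y)
print(outputs.shape)  # torch.Size([5, 3, 5]) -- per-element matching logits
```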
15 | """ 16 | 17 | self.inp_len = input_length 18 | self.out_len = output_length 19 | 20 | def get_ranking(self, pairwise_distance: torch.Tensor) -> Tuple: 21 | """Retrieve the ranks of elements in order of preference of each set in 22 | the set pair. 23 | 24 | Args: 25 | pairwise_distance (torch.Tensor): A 2D tensor of pairwise distances 26 | between the elements of the two sets. 27 | 28 | Returns: 29 | Tuple: Tuple of the ranks of elements of one set in order of 30 | preference of the other set, for both sets. 31 | """ 32 | sorted_input, indices_input = torch.sort(pairwise_distance) 33 | sorted_output, indices_output = torch.sort(pairwise_distance.t()) 34 | 35 | return indices_input, indices_output 36 | 37 | def compute( 38 | self, cost_matrix: torch.Tensor 39 | ) -> torch.Tensor: # Execute for one input at a time 40 | """Compute the Gale-Shapley assignment matrix for the first set as the 41 | "proposer". 42 | 43 | Args: 44 | cost_matrix (torch.Tensor): A 2D tensor containing the costs of 45 | assigning an element of the first set to an element of the 46 | second set. The cost is usually given by the pairwise distance. 47 | 48 | Returns: 49 | torch.Tensor: Binary 2D tensor of assignments of the second set to 50 | the first set in the pair. 51 | """ 52 | 53 | match_matrix = torch.zeros((self.inp_len, self.out_len)) 54 | 55 | preference_inputs, preference_outputs = self.get_ranking(cost_matrix) 56 | 57 | singles = torch.tensor(range(self.inp_len)) 58 | 59 | while singles.nelement() != 0: 60 | 61 | for i in singles: # i is input id 62 | 63 | for j in range(self.out_len): # j is output ranking 64 | 65 | output_preferred = preference_inputs[i, j] # output id 66 | 67 | if 1 in match_matrix[:, output_preferred]: 68 | 69 | # get input id it is matched to 70 | current_match = torch.where( 71 | match_matrix[:, output_preferred] == 1 72 | )[0] 73 | 74 | # get matched input rank from output preferences 75 | 76 | rank_current_match = torch.where( 77 | preference_outputs[output_preferred, :] == current_match 78 | )[0] 79 | 80 | # get potential input rank from output preferences 81 | 82 | rank_potential = torch.where( 83 | preference_outputs[output_preferred, :] == i 84 | )[0] 85 | 86 | if rank_potential < rank_current_match: 87 | match_matrix[i, output_preferred] = 1 88 | match_matrix[current_match, output_preferred] = 0 89 | 90 | singles = torch.cat((singles, current_match)) 91 | singles = singles[singles != i] 92 | 93 | else: 94 | continue 95 | 96 | else: 97 | match_matrix[i, output_preferred] = 1 98 | singles = singles[singles != i] 99 | 100 | break 101 | 102 | return match_matrix 103 | -------------------------------------------------------------------------------- /fdsa/utils/helper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import subprocess 4 | import sys 5 | 6 | 7 | def cpuStats(): 8 | import psutil 9 | print(sys.version) 10 | print(psutil.cpu_percent()) 11 | print(psutil.virtual_memory()) # physical memory usage 12 | pid = os.getpid() 13 | py = psutil.Process(pid) 14 | memoryUse = py.memory_info()[0] / 2.**30 # memory use in GB...I think 15 | print('memory GB:', memoryUse) 16 | 17 | 18 | def get_gpu_memory_map(): 19 | """Get the current gpu usage. 20 | 21 | Returns 22 | ------- 23 | usage: dict 24 | Keys are device ids as integers. 25 | Values are memory usage as integers in MB. 
26 | """ 27 | result = subprocess.check_output( 28 | ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'], 29 | encoding='utf-8' 30 | ) 31 | # Convert lines into a dictionary 32 | gpu_memory = [int(x) for x in result.strip().split('\n')] 33 | gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory)) 34 | return gpu_memory_map 35 | 36 | 37 | def setup_logger(name, log_file, level=logging.INFO): 38 | """To setup as many loggers as you want""" 39 | 40 | handler = logging.FileHandler(log_file) 41 | 42 | logger = logging.getLogger(name) 43 | logger.setLevel(level) 44 | logger.addHandler(handler) 45 | 46 | return logger 47 | -------------------------------------------------------------------------------- /fdsa/utils/hyperparameters.py: -------------------------------------------------------------------------------- 1 | """Script containing factories for various hyperparameters""" 2 | import torch 3 | import torch.distributions as d 4 | import torch.nn as nn 5 | from brc_pytorch.layers import ( 6 | BistableRecurrentCell, NeuromodulatedBistableRecurrentCell 7 | ) 8 | from fdsa.utils.layers.peephole_lstm import PeepholeLSTMCell 9 | from fdsa.utils.mapper import MapperSetsAE 10 | from pytoda.datasets.utils.wrappers import WrapperCDist, WrapperKLDiv 11 | 12 | METRIC_FUNCTION_FACTORY = {'p-norm': WrapperCDist, 'KL': WrapperKLDiv} 13 | 14 | MAPPER_FUNCTION_FACTORY = { 15 | 'HM': MapperSetsAE('HM').get_assignment_matrix_hm, 16 | 'GS': MapperSetsAE('GS').get_assignment_matrix_gs 17 | } 18 | 19 | DISTRIBUTION_FUNCTION_FACTORY = { 20 | 'normal': d.normal.Normal, 21 | 'multinormal': d.multivariate_normal.MultivariateNormal, 22 | 'beta': d.beta.Beta, 23 | 'uniform': d.uniform.Uniform, 24 | 'bernoulli': d.bernoulli.Bernoulli 25 | } 26 | 27 | RNN_CELL_FACTORY = { 28 | 'LSTM': nn.LSTMCell, 29 | 'GRU': nn.GRUCell, 30 | 'BRC': BistableRecurrentCell, 31 | 'nBRC': NeuromodulatedBistableRecurrentCell, 32 | 'pLSTM': PeepholeLSTMCell 33 | } 34 | 35 | RNN_FACTORY = {'LSTM': nn.LSTM, 'GRU': nn.GRU} 36 | 37 | ACTIVATION_FN_FACTORY = { 38 | 'relu': nn.ReLU(), 39 | 'sigmoid': nn.Sigmoid(), 40 | 'selu': nn.SELU(), 41 | 'tanh': nn.Tanh(), 42 | 'lrelu': nn.LeakyReLU(), 43 | 'elu': nn.ELU(), 44 | 'softmax1': nn.Softmax(dim=1), 45 | 'softmax2': nn.Softmax(dim=2) 46 | } 47 | 48 | POOLING_FN_FACTORY = { 49 | 'avg': nn.AvgPool2d, 50 | 'adaptive_avg': nn.AdaptiveAvgPool2d, 51 | 'max': nn.MaxPool2d 52 | } 53 | 54 | LR_SCHEDULER_FACTORY = { 55 | 'step': torch.optim.lr_scheduler.StepLR, 56 | 'exp': torch.optim.lr_scheduler.ExponentialLR, 57 | 'plateau': torch.optim.lr_scheduler.ReduceLROnPlateau 58 | } 59 | -------------------------------------------------------------------------------- /fdsa/utils/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .peephole_lstm import PeepholeLSTMCell # noqa 2 | from .select_item import SelectItem # noqa -------------------------------------------------------------------------------- /fdsa/utils/layers/peephole_lstm.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class PeepholeLSTMCell(nn.Module): 9 | """LSTM with peephole connections.""" 10 | 11 | def __init__( 12 | self, input_size: int, hidden_size: int, bias: bool = True, *args, **kwargs 13 | ) -> None: 14 | """Constructor. 15 | 16 | Args: 17 | input_size (int): Number of input features. 
/fdsa/utils/layers/peephole_lstm.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Tuple
3 | 
4 | import torch
5 | import torch.nn as nn
6 | 
7 | 
8 | class PeepholeLSTMCell(nn.Module):
9 |     """LSTM with peephole connections."""
10 | 
11 |     def __init__(
12 |         self, input_size: int, hidden_size: int, bias: bool = True, *args, **kwargs
13 |     ) -> None:
14 |         """Constructor.
15 | 
16 |         Args:
17 |             input_size (int): Number of input features.
18 |             hidden_size (int): Number of hidden units in the hidden
19 |                 and cell states.
20 |             bias (bool): Whether to include bias. Defaults to True.
21 |         """
22 |         super(PeepholeLSTMCell, self).__init__(*args, **kwargs)
23 | 
24 |         self.input_size = input_size
25 |         self.hidden_size = hidden_size
26 |         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27 | 
28 |         self.weights_x = nn.Parameter(
29 |             torch.FloatTensor(self.input_size, self.hidden_size * 4)
30 |         )
31 |         self.weights_h = nn.Parameter(
32 |             torch.FloatTensor(self.hidden_size, self.hidden_size * 4)
33 |         )
34 |         self.weights_c = nn.Parameter(
35 |             torch.FloatTensor(self.hidden_size, self.hidden_size * 3)
36 |         )
37 | 
38 |         if bias:
39 |             self.bias = nn.Parameter(torch.FloatTensor(self.hidden_size * 4))
40 |         else:
41 |             self.register_buffer('bias', torch.zeros(self.hidden_size * 4))
42 | 
43 |         self.init_params()
44 | 
45 |     def init_params(self) -> None:
46 |         """Uniform Xavier initialisation of weights."""
47 | 
48 |         std_dev = math.sqrt(1 / self.hidden_size)
49 |         for weight in self.parameters():
50 |             weight.data.uniform_(-std_dev, std_dev)
51 | 
52 |     def init_hidden(self, batch_size: int) -> Tuple:
53 |         """Initialise hidden and cell states.
54 | 
55 |         Args:
56 |             batch_size (int): batch size used for training.
57 | 
58 |         Returns:
59 |             Tuple: a tuple containing the hidden state and cell state
60 |                 both initialized to zeros.
61 |         """
62 |         hidden_state = torch.zeros((batch_size, self.hidden_size),
63 |                                    device=self.device)
64 |         cell_state = torch.zeros((batch_size, self.hidden_size),
65 |                                  device=self.device)
66 | 
67 |         return hidden_state, cell_state
68 | 
69 |     def forward(self, data_t: torch.Tensor, states: Tuple) -> Tuple:
70 |         """Single LSTM cell.
71 | 
72 |         Args:
73 |             data_t (torch.Tensor): Element at step t with shape
74 |                 [batch_size, self.input_size]
75 |             states (Tuple): Tuple of the internal states of the recurrent cell:
76 |                 hidden_state (torch.Tensor): Hidden state of the LSTM cell of
77 |                     shape [batch_size, self.hidden_size]
78 |                 cell_state (torch.Tensor): Cell state of the LSTM cell of shape
79 |                     [batch_size, self.hidden_size]
80 | 
81 |         Returns:
82 |             Tuple: A tuple containing the hidden state and cell state
83 |                 after the element is processed.
84 | """ 85 | hidden_state, cell_state = states 86 | 87 | linear_xh = torch.matmul(data_t, self.weights_x) + torch.matmul( 88 | hidden_state, self.weights_h 89 | ) + self.bias 90 | 91 | linear_cxh = linear_xh[:, :self.hidden_size * 2] + torch.matmul( 92 | cell_state, self.weights_c[:, :self.hidden_size * 2] 93 | ) 94 | 95 | forget_prob = torch.sigmoid(linear_cxh[:, :self.hidden_size]) 96 | 97 | input_prob = torch.sigmoid(linear_cxh[:, self.hidden_size:self.hidden_size * 2]) 98 | 99 | candidates = torch.tanh(linear_xh[:, self.hidden_size * 2:self.hidden_size * 3]) 100 | 101 | new_cell_state = forget_prob * cell_state + input_prob * candidates 102 | 103 | linear_output = ( 104 | linear_xh[:, self.hidden_size * 3:self.hidden_size * 4] + torch.matmul( 105 | new_cell_state, 106 | self.weights_c[:, self.hidden_size * 2:self.hidden_size * 3] 107 | ) 108 | ) 109 | 110 | output_gate = torch.sigmoid(linear_output) 111 | 112 | new_hidden_state = torch.tanh(new_cell_state) * output_gate 113 | 114 | return new_hidden_state, new_cell_state 115 | -------------------------------------------------------------------------------- /fdsa/utils/layers/select_item.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SelectItem(nn.Module): 5 | """Select output from a tuple to pass onto next function.""" 6 | 7 | def __init__( 8 | self, item_index: int, elem_idx: int = None, batch_first: bool = True 9 | ) -> None: 10 | """Item selection module that can be included in nn.Sequential(). 11 | 12 | Args: 13 | item_index (int): Index of the item to retrieve. 14 | elem_idx (int, optional): Index from which the selected item should 15 | be sliced. Depending on whether batch_first is True or False, the 16 | dimension along which it is sliced is either 0 or 1. Also, 17 | this implementation slices from the elem_idx until the end. E.g 18 | if elem_idx = -5, then the last 5 elements from the item are 19 | retained. Default: None. 20 | batch_first (bool, optional): Used as an indicator to slice along the 21 | correct dimensions using elem_idx. 22 | """ 23 | super(SelectItem, self).__init__() 24 | self._name = "selectitem" 25 | self.item_index = item_index 26 | self.elem_idx = elem_idx 27 | self.batch_first = batch_first 28 | 29 | def forward(self, inputs): 30 | """Selects item from tuple/list and returns it. 31 | 32 | Args: 33 | inputs (Tuple/List): Tuple/List from which item is to be retrieved. 34 | 35 | Returns: 36 | Tensor/Array: Tensor/Array corresponding to item_index in the inputs. 
37 | """ 38 | if self.elem_idx: 39 | if self.batch_first: 40 | return inputs[self.item_index][:, self.elem_idx:, :] 41 | else: 42 | return inputs[self.item_index][self.elem_idx:, :, :] 43 | else: 44 | return inputs[self.item_index] 45 | -------------------------------------------------------------------------------- /fdsa/utils/layers/tests/test_peephole_lstm.py: -------------------------------------------------------------------------------- 1 | """Testing peephole_lstm""" 2 | import torch 3 | 4 | from fdsa.utils.layers.peephole_lstm import PeepholeLSTMCell 5 | 6 | 7 | def test_lstm_cell(): 8 | 9 | batch_size = 5 10 | input_size = 2 11 | hidden_size = 20 12 | 13 | input_set = torch.rand((batch_size, input_size)) 14 | 15 | for bias in [True, False]: 16 | 17 | lstm = PeepholeLSTMCell(input_size, hidden_size) 18 | hidden_state, cell_state = lstm.init_hidden(batch_size) 19 | 20 | assert type(input_set) == torch.Tensor 21 | 22 | hidden_state, cell_state = lstm(input_set, (hidden_state, cell_state)) 23 | 24 | assert hidden_state.size() == torch.Size([batch_size, hidden_size]) 25 | assert cell_state.size() == torch.Size([batch_size, hidden_size]) 26 | 27 | assert lstm.weights_c.size() == torch.Size([hidden_size, hidden_size * 3]) 28 | assert lstm.weights_h.size() == torch.Size([hidden_size, hidden_size * 4]) 29 | assert lstm.weights_x.size() == torch.Size([input_size, hidden_size * 4]) 30 | assert lstm.bias.size() == torch.Size([hidden_size * 4]) 31 | -------------------------------------------------------------------------------- /fdsa/utils/loss_setae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SetAELoss(nn.Module): 6 | """Class to navigate between using a cross-entropy loss for L^{eos} or the binary 7 | logits loss from PyTorch.""" 8 | 9 | def __init__( 10 | self, 11 | loss: str, 12 | device: torch.device = torch. 13 | device('cuda' if torch.cuda.is_available() else 'cpu') 14 | ) -> None: 15 | """Constructor. 16 | 17 | Args: 18 | loss (str): Loss function to optimise. 19 | device (torch.device): Device on which the model should run. 20 | Defaults to CPU. 21 | """ 22 | super(SetAELoss,self).__init__() 23 | self.loss = loss 24 | self.device = device 25 | self.loss_dict = {'CrossEntropy': self.ce_loss, 'BCELogits': self.bce_loss} 26 | if loss not in self.loss_dict.keys(): 27 | raise NameError( 28 | f'Invalid loss ({loss}). Choose one from {self.loss_dict.keys()}' 29 | ) 30 | 31 | def rmse_loss( 32 | self, inputs: torch.Tensor, mapped_outputs: torch.Tensor 33 | ) -> torch.Tensor: 34 | """Computes the similarity loss using RMSE. 35 | 36 | Args: 37 | inputs (torch.Tensor): Input tensor of shape 38 | [batch_size x sequence_length x input_size]. 39 | mapped_outputs (torch.Tensor): Outputs ordered in correspondence with 40 | the inputs. 41 | 42 | Returns: 43 | torch.Tensor: Similarity loss value for the given batch. 44 | NOTE: loss assumes batch_first is True. 45 | """ 46 | loss_sim = nn.MSELoss() 47 | loss_similarity = torch.sqrt(loss_sim(inputs, mapped_outputs)) 48 | 49 | return loss_similarity 50 | 51 | def ce_loss( 52 | self, member_probabilities: torch.Tensor, batch_lengths: torch.Tensor 53 | ) -> torch.Tensor: 54 | """Computes the cross-entropy loss function for member probability. 55 | 56 | Args: 57 | member_probabilities (torch.Tensor): Probability that an output is a 58 | member of the set. 59 | batch_lengths (torch.Tensor): Tensor of lengths of each set in the 60 | batch. 
/fdsa/utils/loss_setae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class SetAELoss(nn.Module):
6 |     """Class to navigate between using a cross-entropy loss for L^{eos} or the binary
7 |     logits loss from PyTorch."""
8 | 
9 |     def __init__(
10 |         self,
11 |         loss: str,
12 |         device: torch.device = torch.
13 |         device('cuda' if torch.cuda.is_available() else 'cpu')
14 |     ) -> None:
15 |         """Constructor.
16 | 
17 |         Args:
18 |             loss (str): Loss function to optimise. One of 'CrossEntropy' or 'BCELogits'.
19 |             device (torch.device): Device on which the model should run.
20 |                 Defaults to CPU.
21 |         """
22 |         super(SetAELoss, self).__init__()
23 |         self.loss = loss
24 |         self.device = device
25 |         self.loss_dict = {'CrossEntropy': self.ce_loss, 'BCELogits': self.bce_loss}
26 |         if loss not in self.loss_dict.keys():
27 |             raise ValueError(
28 |                 f'Invalid loss ({loss}). Choose one from {self.loss_dict.keys()}'
29 |             )
30 | 
31 |     def rmse_loss(
32 |         self, inputs: torch.Tensor, mapped_outputs: torch.Tensor
33 |     ) -> torch.Tensor:
34 |         """Computes the similarity loss using RMSE.
35 | 
36 |         Args:
37 |             inputs (torch.Tensor): Input tensor of shape
38 |                 [batch_size x sequence_length x input_size].
39 |             mapped_outputs (torch.Tensor): Outputs ordered in correspondence with
40 |                 the inputs.
41 | 
42 |         Returns:
43 |             torch.Tensor: Similarity loss value for the given batch.
44 |             NOTE: loss assumes batch_first is True.
45 |         """
46 |         loss_sim = nn.MSELoss()
47 |         loss_similarity = torch.sqrt(loss_sim(inputs, mapped_outputs))
48 | 
49 |         return loss_similarity
50 | 
51 |     def ce_loss(
52 |         self, member_probabilities: torch.Tensor, batch_lengths: torch.Tensor
53 |     ) -> torch.Tensor:
54 |         """Computes the cross-entropy loss function for member probability.
55 | 
56 |         Args:
57 |             member_probabilities (torch.Tensor): Probability that an output is a
58 |                 member of the set.
59 |             batch_lengths (torch.Tensor): Tensor of lengths of each set in the
60 |                 batch.
61 | 
62 |         Returns:
63 |             torch.Tensor: L^{eos} value for the given batch.
64 |             NOTE: loss assumes batch_first is True.
65 |         """
66 |         batch_size, length, dim = member_probabilities.size()
67 | 
68 |         loss_mem = nn.CrossEntropyLoss()
69 |         true_probabilities = torch.ones_like(
70 |             member_probabilities[:, :, 0], dtype=torch.int64
71 |         )
72 |         mask = torch.zeros(
73 |             batch_size, length + 1, dtype=true_probabilities.dtype, device=self.device
74 |         )
75 | 
76 |         mask[torch.arange(batch_size), batch_lengths] = 1
77 |         mask = mask.cumsum(dim=1)[:, :-1]
78 |         true_probabilities = true_probabilities * (1 - mask)
79 | 
80 |         member_probabilities = member_probabilities.permute(0, 2, 1)
81 |         loss_member = loss_mem(member_probabilities, true_probabilities)
82 | 
83 |         return loss_member
84 | 
85 |     def bce_loss(
86 |         self, member_probabilities: torch.Tensor, batch_lengths: torch.Tensor
87 |     ) -> torch.Tensor:
88 |         """Computes the cross-entropy loss function for member probability using the
89 |         BCELogits loss function.
90 | 
91 |         Args:
92 |             member_probabilities (torch.Tensor): Probability that an output is
93 |                 a member of the set.
94 |             batch_lengths (torch.Tensor): Tensor of lengths of each set in the
95 |                 batch.
96 | 
97 |         Returns:
98 |             torch.Tensor: L^{eos} value for the given batch.
99 |             NOTE: loss assumes batch_first is True.
100 |         """
101 |         batch_size, length, dim = member_probabilities.size()
102 |         # Drop the singleton logit dimension without mutating the caller's tensor.
103 |         member_probabilities = member_probabilities.squeeze(2)
104 | 
105 |         loss_mem = nn.BCEWithLogitsLoss()
106 |         true_probabilities = torch.ones_like(member_probabilities)
107 |         mask = torch.zeros(
108 |             batch_size, length + 1, dtype=true_probabilities.dtype, device=self.device
109 |         )
110 | 
111 |         mask[torch.arange(batch_size), batch_lengths] = 1
112 |         mask = mask.cumsum(dim=1)[:, :-1]
113 | 
114 |         true_probabilities = true_probabilities * (1 - mask)
115 |         loss_member = loss_mem(member_probabilities, true_probabilities)
116 | 
117 |         return loss_member
118 | 
119 |     def forward(
120 |         self, inputs: torch.Tensor, mapped_outputs: torch.Tensor,
121 |         member_probabilities: torch.Tensor, batch_lengths: torch.Tensor
122 |     ) -> torch.Tensor:
123 |         """Computes the total loss by summing similarity and member probability losses.
124 | 
125 |         Args:
126 |             inputs (torch.Tensor): Input tensor of shape
127 |                 [batch_size x sequence_length x input_size].
128 |             mapped_outputs (torch.Tensor): Outputs ordered in correspondence
129 |                 with inputs.
130 |             member_probabilities (torch.Tensor): Probability that an output is
131 |                 a member of the set.
132 |             batch_lengths (torch.Tensor): Tensor of lengths of each set in the
133 |                 batch.
134 | 
135 |         Returns:
136 |             torch.Tensor: Loss value for the given batch.
137 |             NOTE: loss assumes batch_first is True.
138 |         """
139 |         similarity = self.rmse_loss(inputs, mapped_outputs)
140 |         membership = self.loss_dict[self.loss](member_probabilities, batch_lengths)
141 |         return similarity + membership
--------------------------------------------------------------------------------
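A usage sketch of the combined loss (shapes illustrative; the membership head emits one logit per set element):

```python
import torch

from fdsa.utils.loss_setae import SetAELoss

criterion = SetAELoss('BCELogits', torch.device('cpu'))

batch_size, max_len, dim = 4, 6, 8
inputs = torch.rand(batch_size, max_len, dim)
mapped_outputs = torch.rand(batch_size, max_len, dim)  # decoder outputs after matching
member_logits = torch.rand(batch_size, max_len, 1)     # one membership logit per element
lengths = torch.tensor([6, 4, 3, 5])                   # true set sizes

loss = criterion(inputs, mapped_outputs, member_logits, lengths)
print(loss)  # scalar: RMSE similarity + L^{eos} membership term
```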
/fdsa/utils/loss_setmatching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions.one_hot_categorical import OneHotCategorical
5 | 
6 | 
7 | class SetMatchLoss(nn.Module):
8 | 
9 |     def __init__(
10 |         self,
11 |         loss: str,
12 |         ce_type: str = 'row',
13 |         temperature: float = None,
14 |         iterations: int = None
15 |     ) -> None:
16 |         """Constructor.
17 | 
18 |         Args:
19 |             loss (str): Loss function to use. See dictionary below for keys.
20 |             ce_type (str, optional): 'row' applies cross entropy over rows only,
21 |                 'rowcol' over rows and columns. Defaults to 'row'.
22 |             temperature (float, optional): Temperature to apply to logits. Defaults to None.
23 |             iterations (int, optional): Number of iterations for sinkhorn normalisation. Defaults to None.
24 |         """
25 |         super(SetMatchLoss, self).__init__()
26 |         self.loss = loss
27 |         self.temp = temperature
28 |         self.n = iterations
29 |         self.ce_type = ce_type
30 | 
31 |         self.loss_func = dict(
32 |             {
33 |                 'ce_rowcol': self.crossentropy_rowcol,
34 |                 'ce_row': self.crossentropy_row,
35 |                 'ce_l1': self.crossentropy_l1,
36 |                 'ce_l2': self.crossentropy_l2,
37 |                 'ce_l1l2': self.crossentropy_l1l2,
38 |                 'ce_l1l2_penalty': self.crossentropy_l1l2_penalty,
39 |                 'kl_div': self.kl_div_loss,
40 |                 'mask_loss': self.mask_loss,
41 |                 'unique_max': self.unique_max_mask_loss,
42 |                 'sinkhorn': self.sinkhorn_loss,
43 |                 'kl_dist': self.distance_loss
44 |             }
45 |         )
46 | 
47 |     def similarity(
48 |         self, predictions: torch.Tensor, target_mask: torch.Tensor
49 |     ) -> torch.Tensor:
50 |         """Computes the L1-loss between predicted probabilities and true binary
51 |         matrix.
52 | 
53 |         Args:
54 |             predictions (torch.Tensor): Probability matrix with shape
55 |                 [batch_size,set_length,set_length].
56 |             target_mask (torch.Tensor): True binary permutation matrix with
57 |                 shape [batch_size,set_length,set_length].
58 | 
59 |         Returns:
60 |             torch.Tensor: The scalar loss value.
61 |         """
62 |         loss = nn.L1Loss()
63 |         return loss(predictions, target_mask.float())
64 | 
65 |     def crossentropy_row(
66 |         self, predictions: torch.Tensor, target12: torch.Tensor, target21: torch.Tensor
67 |     ) -> torch.Tensor:
68 |         """Computes row-wise cross entropy loss.
69 | 
70 |         Args:
71 |             predictions (torch.Tensor): Logits from the last FC layer with shape
72 |                 [batch_size,set_length,set_length].
73 |             target12 (torch.Tensor): True matching indices such that a one hot
74 |                 embedding produces the permutation matrix for set1 vs set2.
75 |             target21 (torch.Tensor): True matching indices such that a one hot
76 |                 embedding produces the permutation matrix for set2 vs set1.
77 | 
78 |         Returns:
79 |             torch.Tensor: The scalar loss value.
80 |         """
81 | 
82 |         row_constraint = F.log_softmax(predictions / self.temp, dim=2)
83 | 
84 |         row_loss = F.nll_loss(row_constraint.permute(0, 2, 1), target21)
85 | 
86 |         return row_loss
87 | 
88 |     def crossentropy_rowcol(
89 |         self, predictions: torch.Tensor, target12: torch.Tensor, target21: torch.Tensor
90 |     ) -> torch.Tensor:
91 |         """Computes row and column-wise cross entropy loss.
92 | 
93 |         Args:
94 |             predictions (torch.Tensor): Logits from the last FC layer with shape
95 |                 [batch_size,set_length,set_length].
96 |             target12 (torch.Tensor): True matching indices such that a one hot
97 |                 embedding produces the permutation matrix for set1 vs set2.
98 |             target21 (torch.Tensor): True matching indices such that a one hot
99 |                 embedding produces the permutation matrix for set2 vs set1.
100 |         TODO: make this modular, write a function for column cross entropy and
101 |             sum row and column CE functions to get rowcol CE.
102 | 
103 |         Returns:
104 |             torch.Tensor: The scalar loss value.
105 |         """ 
106 | 
107 |         row_constraint = F.log_softmax(predictions / self.temp, dim=2) 
108 |         col_constraint = F.log_softmax(predictions / self.temp, dim=1) 
109 | 
110 |         row_loss = F.nll_loss(row_constraint.permute(0, 2, 1), target21) 
111 |         col_loss = F.nll_loss(col_constraint, target12) 
112 | 
113 |         return row_loss + col_loss 
114 | 
115 |     def crossentropy_l1( 
116 |         self, predictions: torch.Tensor, target12: torch.Tensor, target21: torch.Tensor 
117 |     ) -> torch.Tensor: 
118 |         """Computes the configured cross entropy loss (per ce_type) and an 
119 |         additional L1 similarity loss between the predicted row-wise 
120 |         probabilities and the one hot embedding of target21. 
121 | 
122 |         Args: 
123 |             predictions (torch.Tensor): Logits from the last FC layer with shape 
124 |                 [batch_size,set_length,set_length]. 
125 |             target12 (torch.Tensor): True matching indices such that a one hot 
126 |                 embedding produces the permutation matrix for set1 vs set2. 
127 |             target21 (torch.Tensor): True matching indices such that a one hot 
128 |                 embedding produces the permutation matrix for set2 vs set1. 
129 |         TODO: Make similarity flexible to allow easy switching between row and 
130 |             column L1 loss using arguments. 
131 |         Returns: 
132 |             torch.Tensor: The scalar loss value. 
133 |         """ 
134 | 
135 |         if self.ce_type == 'rowcol': 
136 |             ce_loss = self.crossentropy_rowcol(predictions, target12, target21) 
137 |         elif self.ce_type == 'row': 
138 |             ce_loss = self.crossentropy_row(predictions, target12, target21) 
139 | 
140 |         true_mask21 = F.one_hot(target21).float() 
141 | 
142 |         row_constraint = F.log_softmax(predictions / self.temp, dim=2) 
143 | 
144 |         similarity = self.similarity(row_constraint, true_mask21) 
145 | 
146 |         return ce_loss + similarity 
147 | 
148 |     def crossentropy_l2( 
149 |         self, predictions: torch.Tensor, target12: torch.Tensor, target21: torch.Tensor 
150 |     ) -> torch.Tensor: 
151 |         """Samples from the row and column probability distributions and 
152 |         computes the L2-norm between the sampled probabilities. The loss is 
153 |         the sum of the configured cross entropy (per ce_type) and the L2-norm. 
154 | 
155 |         Args: 
156 |             predictions (torch.Tensor): Logits from the last FC layer with shape 
157 |                 [batch_size,set_length,set_length]. 
158 |             target12 (torch.Tensor): True matching indices such that a one hot 
159 |                 embedding produces the permutation matrix for set1 vs set2. 
160 |             target21 (torch.Tensor): True matching indices such that a one hot 
161 |                 embedding produces the permutation matrix for set2 vs set1. 
162 | 
163 |         Returns: 
164 |             torch.Tensor: The scalar loss value. 
165 |         """ 
166 | 
167 |         if self.ce_type == 'rowcol': 
168 |             ce_loss = self.crossentropy_rowcol(predictions, target12, target21) 
169 |         elif self.ce_type == 'row': 
170 |             ce_loss = self.crossentropy_row(predictions, target12, target21) 
171 | 
172 |         row_softmax = F.softmax(predictions / self.temp, dim=2) 
173 |         col_softmax = F.softmax(predictions / self.temp, dim=1).permute(0, 2, 1) 
174 | 
175 |         row_mask = OneHotCategorical(row_softmax).sample() 
176 | 
177 |         col_mask = OneHotCategorical(col_softmax).sample() 
178 | 
179 |         row_filtered = row_softmax * row_mask 
180 |         col_filtered = col_softmax * col_mask 
181 | 
182 |         loss = nn.MSELoss() 
183 | 
184 |         l2_norm = torch.sqrt(loss(row_filtered, col_filtered.permute(0, 2, 1))) 
185 | 
186 |         return ce_loss + l2_norm 
187 | 
188 |     def crossentropy_l1l2( 
189 |         self, predictions: torch.Tensor, target12: torch.Tensor, target21: torch.Tensor 
190 |     ) -> torch.Tensor: 
191 |         """Computes L1 and L2 penalties in addition to cross entropy loss. 
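        (Implemented as crossentropy_l2, which already includes the
        configured cross entropy term, plus the L1 similarity between the
        row-wise log-softmax and the one hot embedding of target21.)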
192 | 
193 |         Args: 
194 |             predictions (torch.Tensor): Logits from the last FC layer with shape 
195 |                 [batch_size,set_length,set_length]. 
196 |             target12 (torch.Tensor): True matching indices such that a one hot 
197 |                 embedding produces the permutation matrix for set1 vs set2. 
198 |             target21 (torch.Tensor): True matching indices such that a one hot 
199 |                 embedding produces the permutation matrix for set2 vs set1. 
200 | 
201 |         Returns: 
202 |             torch.Tensor: The scalar loss value. 
203 |         """ 
204 | 
205 |         ce_loss = self.crossentropy_l2(predictions, target12, target21) 
206 | 
207 |         true_mask21 = F.one_hot(target21).float() 
208 | 
209 |         row_constraint = F.log_softmax(predictions / self.temp, dim=2) 
210 | 
211 |         similarity = self.similarity(row_constraint, true_mask21) 
212 | 
213 |         return ce_loss + similarity 
214 | 
215 |     def kl_div_loss( 
216 |         self, predictions: torch.Tensor, target12: torch.Tensor, target21: torch.Tensor 
217 |     ) -> torch.Tensor: 
218 |         """Computes the KL-Divergence between log softmax of logits and binary 
219 |         matrix of true targets both row and column-wise. 
220 | 
221 |         Args: 
222 |             predictions (torch.Tensor): Logits from the last FC layer with shape 
223 |                 [batch_size,set_length,set_length]. 
224 |             target12 (torch.Tensor): True matching indices such that a one hot 
225 |                 embedding produces the permutation matrix for set1 vs set2. 
226 |             target21 (torch.Tensor): True matching indices such that a one hot 
227 |                 embedding produces the permutation matrix for set2 vs set1. 
228 | 
229 |         Returns: 
230 |             torch.Tensor: The scalar loss value. 
231 |         """ 
232 | 
233 |         # output vs input 
234 |         row_constraint = F.log_softmax(predictions, dim=2) 
235 |         # input vs output 
236 |         col_constraint = F.log_softmax(predictions, dim=1).permute(0, 2, 1) 
237 | 
238 |         # input vs output permutation matrix 
239 |         one_hot_target12 = F.one_hot(target12).type(torch.float32) 
240 |         # output vs input permutation matrix 
241 |         one_hot_target21 = F.one_hot(target21).type(torch.float32) 
242 | 
243 |         loss = nn.KLDivLoss(reduction='batchmean') 
244 | 
245 |         return loss(row_constraint, 
246 |                     one_hot_target21) + loss(col_constraint, one_hot_target12) 
247 | 
248 |     def mask_loss( 
249 |         self, predictions: torch.Tensor, target12: torch.Tensor, target21: torch.Tensor 
250 |     ) -> torch.Tensor: 
251 |         """Computes the row-wise Cross Entropy Loss and the L1-norm between 
252 |         the column-wise sum of the predicted permutation matrix and a 
253 |         vector of ones. 
254 | 
255 |         Args: 
256 |             predictions (torch.Tensor): Logits from the last FC layer with shape 
257 |                 [batch_size,set_length,set_length]. 
258 |             target12 (torch.Tensor): True matching indices such that a one hot 
259 |                 embedding produces the permutation matrix for set1 vs set2. 
260 |             target21 (torch.Tensor): True matching indices such that a one hot 
261 |                 embedding produces the permutation matrix for set2 vs set1. 
262 | 
263 |         Returns: 
264 |             torch.Tensor: The scalar loss value. 
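        Example (hypothetical 3x3 case): if the row-wise argmax picks
        column 0 twice and column 1 once, the predicted mask has column
        sums [2, 1, 0], so the L1 term penalises the deviation from the
        all-ones vector [1, 1, 1].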
265 | """ 266 | true_mask = F.one_hot(target21) 267 | softmax_predictions = F.softmax(predictions, 2) 268 | 269 | softmax_argmax = torch.argmax(softmax_predictions, 2) 270 | pred_mask = F.one_hot(softmax_argmax, true_mask.size(2)).float() 271 | 272 | loss_fn = nn.CrossEntropyLoss() 273 | 274 | ce = loss_fn(predictions.permute(0, 2, 1), target21) 275 | col_constraint = torch.ones( 276 | true_mask.size(0), true_mask.size(2), device=target21.device 277 | ) 278 | 279 | constrained_loss = nn.L1Loss()(torch.sum(pred_mask, dim=1), col_constraint) 280 | 281 | return ce + constrained_loss 282 | 283 | def sinkhorn_normalisation(self, predictions: torch.Tensor) -> torch.Tensor: 284 | """Returns a doubly stochastic matrix by Sinkhorn-Knopp normalisation. 285 | 286 | Args: 287 | predictions (torch.Tensor): Logits from the last FC layer with shape 288 | [batch_size,set_length,set_length]. 289 | 290 | Returns: 291 | torch.Tensor: Doubly stochastic matrix (DSM) such that row and 292 | column sum to 1. Shape: [batch_size,set_length,set_length]. 293 | """ 294 | 295 | def row_norm(predictions): 296 | """Performs row-wise normalisation.""" 297 | 298 | return F.normalize(predictions, p=1, dim=2) 299 | 300 | def col_norm(predictions): 301 | """Performs column-wise normalisation.""" 302 | 303 | return F.normalize(predictions, p=1, dim=1) 304 | 305 | positive_matrix = F.relu(predictions) 306 | sinkhorn = positive_matrix + 1e-9 307 | 308 | if self.n is not None: 309 | for i in range(self.n): 310 | sinkhorn = col_norm(row_norm(sinkhorn)) 311 | 312 | return sinkhorn 313 | 314 | def sinkhorn_loss( 315 | self, predictions: torch.Tensor, target12: torch.Tensor, target21: torch.Tensor 316 | ) -> torch.Tensor: 317 | """Computes KL-divergence between the DSM and true permutation matrix 318 | both row and column-wise. 319 | 320 | Args: 321 | predictions (torch.Tensor): Logits from the last FC layer with shape 322 | [batch_size,set_length,set_length]. 323 | target12 (torch.Tensor): True matching indices such that a one hot 324 | embedding produces the permutation matrix for set1 vs set2. 325 | target21 (torch.Tensor): True matching indices such that a one hot 326 | embedding produces the permutation matrix for set2 vs set1. 327 | 328 | Returns: 329 | torch.Tensor: The scalar loss value. 330 | """ 331 | doubly_stochastic_matrix = self.sinkhorn_normalisation(predictions) 332 | 333 | true_mask21 = F.one_hot(target21).float() 334 | true_mask12 = F.one_hot(target12).float() 335 | loss_fn = nn.KLDivLoss(reduction='batchmean') 336 | 337 | log_dsm = torch.log(doubly_stochastic_matrix) 338 | 339 | loss = loss_fn(log_dsm, 340 | true_mask21) + loss_fn(log_dsm.permute(0, 2, 1), true_mask12) 341 | 342 | return loss 343 | 344 | def l1l2_penalty(self, predictions: torch.Tensor) -> torch.Tensor: 345 | """Computes L1-L2 matrix penalty as shown in AutoShuffle Net. 346 | https://arxiv.org/pdf/1901.08624.pdf 347 | 348 | Args: 349 | predictions (torch.Tensor): Logits from the last FC layer with shape 350 | [batch_size,set_length,set_length]. 351 | 352 | Returns: 353 | torch.Tensor: The scalar penalty value. 
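        Note: for a non-negative vector p, ||p||_1 >= ||p||_2 with equality
        iff p has at most one non-zero entry, so the penalty vanishes
        exactly when every softmax row and column is a hard (one-hot)
        assignment.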
354 |         """ 
355 | 
356 |         row_softmax = F.softmax(predictions / self.temp, dim=2) 
357 |         col_softmax = F.softmax(predictions / self.temp, dim=1) 
358 | 
359 |         row_l1 = torch.norm(row_softmax, p=1, dim=2) 
360 |         row_l2 = torch.norm(row_softmax, p=2, dim=2) 
361 | 
362 |         col_l1 = torch.norm(col_softmax, p=1, dim=1) 
363 |         col_l2 = torch.norm(col_softmax, p=2, dim=1) 
364 | 
365 |         p = torch.sum(row_l1 - row_l2) + torch.sum(col_l1 - col_l2) 
366 | 
367 |         return p 
368 | 
369 |     def crossentropy_l1l2_penalty( 
370 |         self, predictions: torch.Tensor, target12: torch.Tensor, target21: torch.Tensor 
371 |     ) -> torch.Tensor: 
372 |         """Computes cross entropy loss in addition to the L1-L2 matrix penalty. 
373 | 
374 |         Args: 
375 |             predictions (torch.Tensor): Logits from the last FC layer with shape 
376 |                 [batch_size,set_length,set_length]. 
377 |             target12 (torch.Tensor): True matching indices such that a one hot 
378 |                 embedding produces the permutation matrix for set1 vs set2. 
379 |             target21 (torch.Tensor): True matching indices such that a one hot 
380 |                 embedding produces the permutation matrix for set2 vs set1. 
381 | 
382 | 
383 |         Returns: 
384 |             torch.Tensor: The scalar loss value. 
385 |         """ 
386 | 
387 |         if self.ce_type == 'rowcol': 
388 |             ce_loss = self.crossentropy_rowcol(predictions, target12, target21) 
389 |         elif self.ce_type == 'row': 
390 |             ce_loss = self.crossentropy_row(predictions, target12, target21) 
391 | 
392 |         penalty = self.l1l2_penalty(predictions) 
393 | 
394 |         return ce_loss + penalty 
395 | 
396 |     def unique_max_mask_loss( 
397 |         self, predictions: torch.Tensor, target12: torch.Tensor, target21: torch.Tensor 
398 |     ) -> torch.Tensor: 
399 |         """Combines the maximum probabilities row and column-wise into a single 
400 |         sparse probability matrix, and then uses the column-wise max to 
401 |         predict target12. This prediction is then evaluated using Cross 
402 |         Entropy Loss. 
403 | 
404 |         Args: 
405 |             predictions (torch.Tensor): Logits from the last FC layer with shape 
406 |                 [batch_size,set_length,set_length]. 
407 |             target12 (torch.Tensor): True matching indices such that a one hot 
408 |                 embedding produces the permutation matrix for set1 vs set2. 
409 |             target21 (torch.Tensor): True matching indices such that a one hot 
410 |                 embedding produces the permutation matrix for set2 vs set1. 
411 | 
412 |         Returns: 
413 |             torch.Tensor: The scalar loss value. 
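        Note: the combination row_softmax * row_mask + col_softmax *
        col_mask keeps only each row's and each column's maximum
        probability, so entries on which both directions agree are
        reinforced before the final column-wise argmax.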
414 |         """ 
415 | 
416 |         classes = target12.size()[1] 
417 |         row_softmax = F.softmax(predictions, 2) 
418 |         col_softmax = F.softmax(predictions, 1) 
419 | 
420 |         row_argmax = torch.argmax(row_softmax, 2) 
421 |         col_argmax = torch.argmax(col_softmax, 1) 
422 | 
423 |         row_mask = F.one_hot(row_argmax, classes) 
424 |         # Transpose to get the correct out vs in orientation 
425 |         col_mask = F.one_hot(col_argmax, classes).permute(0, 2, 1) 
426 | 
427 |         combined_sparse_probabilities = row_softmax * row_mask + col_softmax * col_mask 
428 | 
429 |         # orientation of combined tensor is out vs in 
430 |         predicted_class12 = torch.argmax(combined_sparse_probabilities, 1) 
431 |         # since prediction is along 1st dim, we need to transpose one_hot 
432 |         # to get the correct mask in out vs in orientation 
433 |         predicted_mask12 = F.one_hot(predicted_class12, classes).permute(0, 2, 1) 
434 | 
435 |         final_predictions12 = combined_sparse_probabilities * predicted_mask12 
436 | 
437 |         # final_predictions are of orientation out vs in where probs are along 1st dim 
438 |         # so no need of permuting when calculating nll loss with target12 
439 |         # because for multi-dim the shape reqd is batch x class_prob x seq_len 
440 | 
441 |         ce = nn.CrossEntropyLoss() 
442 | 
443 |         loss = ce(final_predictions12, target12) 
444 | 
445 |         return loss 
446 | 
447 |     def forward( 
448 |         self, predictions: torch.Tensor, target12: torch.Tensor, target21: torch.Tensor 
449 |     ) -> torch.Tensor: 
450 |         """Returns the desired loss value. 
451 | 
452 |         Args: 
453 |             predictions (torch.Tensor): Logits from the last FC layer with shape 
454 |                 [batch_size,set_length,set_length]. 
455 |             target12 (torch.Tensor): True matching indices such that a one hot 
456 |                 embedding produces the permutation matrix for set1 vs set2. 
457 |             target21 (torch.Tensor): True matching indices such that a one hot 
458 |                 embedding produces the permutation matrix for set2 vs set1. 
459 | 
460 |         Returns: 
461 |             torch.Tensor: The scalar loss value. 
462 |         """ 
463 | 
464 |         return self.loss_func[self.loss](predictions, target12, target21) 
465 | 
--------------------------------------------------------------------------------
/fdsa/utils/mapper.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | 
3 | # import lap
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from fdsa.utils.gale_shapley import GaleShapley
8 | from scipy.optimize import linear_sum_assignment
9 | 
10 | 
11 | class MapperSetsAE(nn.Module):
12 |     """Mapping Algorithm for Sets AutoEncoder."""
13 | 
14 |     def __init__(
15 |         self,
16 |         matcher='HM',
17 |         p: int = 2,
18 |         device: torch.device = torch.
19 |         device('cuda' if torch.cuda.is_available() else 'cpu')
20 |     ) -> None:
21 |         """Constructor.
22 | 
23 |         Args:
24 |             matcher (str): The matching algorithm to use.
25 |                 One of 'HM' (Munkres version of the Hungarian algorithm),
26 |                 or 'GS' (Gale-Shapley algorithm).
27 |                 Defaults to 'HM'.
28 |             p (int, optional): the p-norm to use when calculating the
29 |                 cost matrix. Defaults to 2.
30 |             device (torch.device): Device on which to run the model.
31 |                 Defaults to GPU if available, else CPU.
32 |         """
33 |         super(MapperSetsAE, self).__init__()
34 |         self.p = p
35 |         self.matcher = matcher
36 |         self.method = dict(
37 |             {
38 |                 'HM': self.get_assignment_matrix_hm,
39 |                 'GS': self.get_assignment_matrix_gs
40 |             }
41 |         )
42 | 
43 |         self.device = device
44 | 
45 |     def get_assignment_matrix_hm(self, cost_matrix: torch.Tensor) -> Tuple:
46 |         """Runs the Munkres version of the Hungarian algorithm.
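        (Thin wrapper around scipy.optimize.linear_sum_assignment, which
        solves the minimum-cost bipartite matching problem.)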
47 | 
48 |         Args: 
49 |             cost_matrix (torch.Tensor): A 2-D tensor that represents the cost 
50 |                 of matching a row (input) and column (output). Has dimensions 
51 |                 N x M, where N is the length of inputs and M the length of 
52 |                 outputs. 
53 | 
54 |         Returns: 
55 |             Tuple: Tuple of 2-D binary matrix with the same dimensions as the 
56 |                 cost matrix, where 1 represents a match and 0 otherwise, and 
57 |                 row-wise nonzero indices of the matrix. 
58 |         """ 
59 |         matrix = torch.zeros_like(cost_matrix) 
60 |         rows, cols = linear_sum_assignment(cost_matrix.detach().cpu().numpy()) 
61 |         matrix[rows, cols] = 1 
62 | 
63 |         return torch.as_tensor(matrix), cols 
64 | 
65 |     # def get_assignment_matrix_vj( 
66 |     #     self, cost_matrix: torch.Tensor 
67 |     # ) -> torch.Tensor: 
68 |     #     """Runs the Jonker-Volgenant algorithm. 
69 | 
70 |     #     Args: 
71 |     #         cost_matrix (torch.Tensor): A 2-D tensor that represents the cost 
72 |     #             of matching a row (input) and column (output). Has dimensions 
73 |     #             N x M, where N is the length of inputs and M the length of 
74 |     #             outputs. 
75 | 
76 |     #     Returns: 
77 |     #         Tuple: Tuple of 2-D binary matrix with the same dimensions as the 
78 |     #             cost matrix, where 1 represents a match and 0 otherwise, and 
79 |     #             row-wise nonzero indices of the matrix. 
80 |     #     """ 
81 |     #     matrix = torch.zeros_like(cost_matrix) 
82 |     #     cost, cols, rows = lap.lapjv( 
83 |     #         cost_matrix.detach().cpu().numpy(), extend_cost=True 
84 |     #     ) 
85 |     #     matrix[range(len(cols)), cols] = 1 
86 | 
87 |     #     return torch.as_tensor(matrix), cols 
88 | 
89 |     def get_assignment_matrix_gs(self, cost_matrix: torch.Tensor) -> Tuple: 
90 |         """Runs the Gale-Shapley Stable Marriage algorithm. 
91 | 
92 |         Args: 
93 |             cost_matrix (torch.Tensor): A 2-D tensor that represents the cost 
94 |                 of matching a row (input) and column (output). Has dimensions 
95 |                 N x M, where N is the length of inputs and M the length of 
96 |                 outputs. 
97 | 
98 |         Returns: 
99 |             Tuple: Tuple of 2-D binary matrix with the same dimensions as the 
100 |                 cost matrix, where 1 represents a match and 0 otherwise, and 
101 |                 row-wise nonzero indices of the matrix. 
102 |         """ 
103 | 
104 |         gs = GaleShapley(cost_matrix.size()[0], cost_matrix.size()[1]) 
105 |         binary_matrix = gs.compute(cost_matrix) 
106 |         rows, cols = np.nonzero(binary_matrix) 
107 |         return binary_matrix, cols 
108 | 
109 |     def output_mapping( 
110 |         self, outputs: torch.Tensor, match_matrix: torch.Tensor 
111 |     ) -> torch.Tensor: 
112 |         """Orders the outputs based on the match matrix. 
113 | 
114 |         Args: 
115 |             outputs (torch.Tensor): The set of outputs generated by the decoder. 
116 |             match_matrix (torch.Tensor): A 2-D binary matrix, where 1 
117 |                 represents a match and 0 otherwise. 
118 |                 Has the same dimensions as the cost matrix. 
119 |         Returns: 
120 |             torch.Tensor: Outputs ordered in correspondence with inputs. 
121 |         """ 
122 |         return (match_matrix[..., None] * outputs[None, ...]).sum(dim=1) 
123 | 
124 |     def forward( 
125 |         self, inputs: torch.Tensor, stacked_outputs: torch.Tensor, 
126 |         member_probabilities: torch.Tensor 
127 |     ) -> Tuple: 
128 |         """Computes cost matrix and performs a matching between inputs and outputs. 
129 | 
130 |         Args: 
131 |             inputs (torch.Tensor): Input tensor of shape 
132 |                 [batch_size x sequence_length x input_size]. 
133 |             stacked_outputs (torch.Tensor): Reconstructed elements from the 
134 |                 decoder with shape [batch_size, max_length, input_size]. 
135 |             member_probabilities (torch.Tensor): Probabilities describing the 
136 |                 likelihood of elements belonging to the set. 
137 | 
138 |         Returns: 
139 |             Tuple: Tuple of the outputs and their membership probabilities 
140 |                 reordered with respect to the input, plus the matched column indices. 
141 |         """ 
142 | 
143 |         in_batch_size, input_length, input_size = inputs.size() 
144 |         out_batch_size, output_length, output_size = stacked_outputs.size() 
145 |         mapped_outputs = [] 
146 | 
147 |         with torch.no_grad(): 
148 |             cost_matrices = list( 
149 |                 map(torch.cdist, inputs, stacked_outputs, [self.p] * in_batch_size) 
150 |             ) 
151 | 
152 |             match_matrices, cols = map( 
153 |                 list, zip(*map(self.method[self.matcher], cost_matrices)) 
154 |             ) 
155 | 
156 |         mapped_outputs = list(map(self.output_mapping, stacked_outputs, match_matrices)) 
157 | 
158 |         mapped_prob = list( 
159 |             map(self.output_mapping, member_probabilities, match_matrices) 
160 |         ) 
161 | 
162 |         return torch.stack(mapped_outputs), torch.stack(mapped_prob), np.stack(cols) 
--------------------------------------------------------------------------------
/fdsa/utils/setsae_setm.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | from fdsa.models.set_matching.rnn import RNNSetMatching
7 | from fdsa.models.set_matching.seq2seq import Seq2Seq
8 | 
9 | 
10 | class NetworkMapperSetsAE(nn.Module):
11 |     """Pre-Trained Network as a Mapper for Sets AutoEncoder."""
12 | 
13 |     def __init__(
14 |         self,
15 |         model: str,
16 |         model_path: str,
17 |         params: dict,
18 |         device: torch.device = torch.
19 |         device('cuda' if torch.cuda.is_available() else 'cpu'),
20 |         connector_value: float = 99.0,
21 |     ) -> None:
22 |         """Loads and initialises model to match two given sets.
23 | 
24 |         Args:
25 |             model (str): String value indicating the architecture style of the
26 |                 matching network. One of 'rnn' or 'seq2seq'.
27 |             model_path (str): Path to where the pre-trained matching network
28 |                 is saved.
29 |             params (dict): Parameters necessary to initialise the matching
30 |                 network. These should be the same parameters used during the
31 |                 pre-training of the model.
32 |             device (torch.device): Device on which to run the model.
33 |                 Defaults to GPU if available, else CPU.
34 |             connector_value (float, optional): Fill value for the constant
35 |                 connector tensor of shape [batch_size, 1, input_size] that
36 |                 joins the two sets to be matched. Defaults to 99.0.
37 |         """
38 |         super(NetworkMapperSetsAE, self).__init__()
39 | 
40 |         model_dict = {'seq2seq': Seq2Seq, 'rnn': RNNSetMatching}
41 |         self.padding_value = params.get('padding_value', 4.0)
42 |         self.batch_first = eval(params.get('batch_first', 'False'))  # the flag is stored as a string in params
43 |         self.device = device
44 |         self.connector = torch.full(
45 |             (1, params.get('input_size', 128)), connector_value
46 |         ).to(device)
47 |         self.model = model
48 |         self.mapper = model_dict[model](params, device).to(device)
49 | 
50 |         checkpoint = torch.load(model_path)
51 |         self.mapper.load_state_dict(checkpoint["model_state_dict"])
52 | 
53 |     def output_mapping(
54 |         self, outputs: torch.Tensor, match_matrix: torch.Tensor
55 |     ) -> torch.Tensor:
56 |         """Orders the outputs based on the match matrix.
57 | 
58 |         Args:
59 |             outputs (torch.Tensor): The set of outputs generated by the decoder.
60 |             match_matrix (torch.Tensor): A 2-D binary matrix, where 1
61 |                 represents a match and 0 otherwise.
62 |                 Has the same dimensions as the cost matrix.
63 |         Returns:
64 |             torch.Tensor: Outputs ordered in correspondence with inputs.
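            Note: for a single set this broadcast-and-sum is equivalent to
            the matrix product match_matrix.float() @ outputs.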
65 |         """ 
66 |         return (match_matrix[..., None] * outputs[None, ...]).sum(dim=1) 
67 | 
68 |     def forward( 
69 |         self, inputs: torch.Tensor, stacked_outputs: torch.Tensor, 
70 |         member_probabilities: torch.Tensor 
71 |     ) -> Tuple: 
72 |         """Runs the pre-trained matching network to match inputs and outputs. 
73 | 
74 |         Args: 
75 |             inputs (torch.Tensor): Input tensor of shape 
76 |                 [batch_size x sequence_length x input_size]. 
77 |             stacked_outputs (torch.Tensor): Reconstructed elements from the 
78 |                 decoder with shape [batch_size, max_length, input_size]. 
79 |             member_probabilities (torch.Tensor): Probabilities describing the 
80 |                 likelihood of elements belonging to the set. 
81 | 
82 |         Returns: 
83 |             Tuple: Tuple of the outputs and their membership probabilities 
84 |                 reordered with respect to the input, plus the predicted assignments. 
85 |         """ 
86 | 
87 |         in_batch_size, input_length, input_size = inputs.size() 
88 |         out_batch_size, output_length, output_size = stacked_outputs.size() 
89 |         mapped_outputs = [] 
90 | 
91 |         with torch.no_grad(): 
92 | 
93 |             connector_ = self.connector.expand(in_batch_size, -1, -1) 
94 | 
95 |             x = torch.cat((inputs, connector_, stacked_outputs), dim=1).to(self.device) 
96 | 
97 |             self.mapper.eval() 
98 | 
99 |             if self.model == 'seq2seq': 
100 |                 test_output, _ = self.mapper( 
101 |                     x.permute(1, 0, 2), inputs.permute(1, 0, 2) 
102 |                 ) 
103 |             else: 
104 |                 test_output = self.mapper(x) 
105 | 
106 |             if self.batch_first is False: 
107 |                 test_output = test_output.permute(1, 0, 2) 
108 | 
109 |             assignments12 = torch.argmax(test_output, 2) 
110 | 
111 |             match_matrices = F.one_hot(assignments12, input_length).float() 
112 | 
113 |         mapped_outputs = list(map(self.output_mapping, stacked_outputs, match_matrices)) 
114 | 
115 |         mapped_prob = list( 
116 |             map(self.output_mapping, member_probabilities, match_matrices) 
117 |         ) 
118 | 
119 |         return torch.stack(mapped_outputs), torch.stack(mapped_prob), assignments12 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytoda @ git+https://github.com/PaccMann/paccmann_datasets@0.2.4
2 | brc-pytorch>=0.1.3
3 | numpy>=1.14.3
4 | scipy>=1.3.1
5 | torch>=1.3.0
6 | # data analysis
7 | pandas>=0.24.2,<1.0
8 | scikit-learn>=0.22.1
9 | matplotlib>=3.1.1
10 | seaborn>=0.9.0
11 | astropy==4.0.1.post1
12 | scikit-image==0.16.2
13 | imageio==2.6.1
14 | pytest==5.4.2
15 | requests>=2.23.0
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """Install package."""
2 | from setuptools import setup, find_packages
3 | 
4 | setup(
5 |     name='fdsa',
6 |     version='0.0.1',
7 |     description=('Fully differentiable set autoencoder.'),
8 |     long_description=open('README.md').read(),
9 |     url='https://github.com/PaccMann/fdsa',
10 |     author='PaccMann team',
11 |     author_email=(
12 |         'nja@zurich.ibm.com, tte@zurich.ibm.com, jab@zurich.ibm.com'
13 |     ),
14 |     install_requires=[
15 |         'numpy', 'pandas', 'scipy', 'torch', 'requests', 'astropy', 'scikit-image',
16 |         'brc-pytorch>=0.1.3', 'scikit-learn>=0.22.1',
17 |         'pytoda @ git+https://github.com/PaccMann/paccmann_datasets@0.2.4'
18 |     ],
19 |     packages=find_packages('.'),
20 |     zip_safe=False,
21 | )
22 | 
--------------------------------------------------------------------------------
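
Usage sketch: a minimal, hypothetical example of running the Hungarian-matching mapper from fdsa/utils/mapper.py on random data; the shapes and values are illustrative only and not part of the repository.

    import torch

    from fdsa.utils.mapper import MapperSetsAE

    # Match 4 sets of 6 reconstructed elements back to their inputs (CPU run).
    mapper = MapperSetsAE(matcher='HM', p=2, device=torch.device('cpu'))
    inputs = torch.randn(4, 6, 128)           # [batch, set_length, input_size]
    reconstructions = torch.randn(4, 6, 128)  # decoder outputs, same shape here
    membership = torch.rand(4, 6, 1)          # per-element membership scores

    mapped_out, mapped_prob, cols = mapper(inputs, reconstructions, membership)
    print(mapped_out.shape, mapped_prob.shape, cols.shape)  # (4, 6, 128), (4, 6, 1), (4, 6)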