├── .gitignore ├── LICENSE ├── README.md ├── mofreinforce ├── __init__.py ├── cli │ ├── __init__.py │ ├── download.py │ └── main.py ├── data_preparation.ipynb ├── generator │ ├── __init__.py │ ├── config_generator.py │ ├── datamodule.py │ ├── dataset.py │ ├── logs │ │ └── v0_test_seed0_from_generator │ │ │ └── version_0 │ │ │ ├── events.out.tfevents.1676286793.park.160184.0 │ │ │ └── hparams.yaml │ ├── module.py │ ├── objectives.py │ └── transformer │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── layers.py │ │ └── transformer.py ├── libs │ └── selfies │ │ ├── README.md │ │ ├── __init__.py │ │ ├── bond_constraints.py │ │ ├── compatibility.py │ │ ├── constants.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ ├── exceptions.py │ │ ├── grammar_rules.py │ │ ├── mol_graph.py │ │ └── utils │ │ ├── __init__.py │ │ ├── encoding_utils.py │ │ ├── matching_utils.py │ │ ├── selfies_utils.py │ │ └── smiles_utils.py ├── predictor │ ├── __init__.py │ ├── baseline_model.py │ ├── config_predictor.py │ ├── config_predictor_ex.py │ ├── datamodule.py │ ├── dataset.py │ ├── gadgets.py │ ├── heads.py │ ├── module.py │ ├── objectives.py │ └── transformer.py ├── reinforce │ ├── __init__.py │ ├── config_reinforce.py │ ├── module.py │ └── reinforce.py ├── run_generator.py ├── run_predictor.py ├── run_reinforce.py ├── tutorial.ipynb └── utils │ ├── __init__.py │ ├── download.py │ ├── gadgets.py │ ├── metrics.py │ └── module_utils.py ├── predictor.md ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | 3 | *.txt 4 | *.out 5 | *.log 6 | *.ipynb 7 | *.sh 8 | *.npz 9 | *.tar.gz 10 | !data_preparation.ipynb 11 | 12 | .ipynb_checkpoints/ 13 | 14 | 15 | # figure 16 | figure/ 17 | 18 | # log directory 19 | predictor/logs/ 20 | predictor/result_transformer 21 | 22 | # data directory 23 | 1_data_preprocessing/ 24 | 25 | # generator 26 | generator/logs/ 27 | 28 | # reinforce 29 | reinforce/logs*/ 30 | 
reinforce/old_logs*/ 31 | 32 | # lib 33 | PORMAKE/ 34 | 35 | # big file size 36 | model/ 37 | data/ 38 | raw_data/ 39 | _data/ 40 | assets/ 41 | results/ 42 | test/ 43 | logs/ 44 | logs_official/ 45 | tmp/ 46 | 47 | # setuptools 48 | build/ 49 | dist/ 50 | .eggs/ 51 | *.egg-info 52 | *.eggs -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Hyunsoo Park 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reinforcement Learning Framework For MOFs 🚀 2 | ![scheme_rl-01](https://user-images.githubusercontent.com/64190846/218362539-740997c9-d198-4e0a-89e0-3277c5b45a51.jpg) 3 | This repository is a reinforcement learning framework for Metal-Organic Frameworks (MOFs), designed to generate MOF structures with user-defined properties. 🔍 4 | 5 | The framework consists of two key components: the agent and the environment. The agent (i.e., generator) generates MOF structures by taking actions, which are then evaluated by the environment (i.e., predictor) to predict the properties of the generated MOFs. Based on the prediction, a reward is returned to the agent, which is then used to generate the next round of MOFs, continually improving the generation process. 6 | 7 | ## Installation - Get started in minutes! 🌟 8 | 9 | ### OS and hardware requirements 10 | This package requires Linux Ubuntu 20.04 or 22.04. For optimal performance, we recommend running it with GPUs. 11 | 12 | ### Dependencies 13 | This package requires Python 3.8 or higher. 14 | 15 | ### Install 16 | To install this package, install **PyTorch** (version 1.12.0 or higher) according to your environment, and then follow these steps: 17 | 18 | ``` 19 | $ git clone https://github.com/hspark1212/MOFreinforce.git 20 | $ cd MOFreinforce 21 | $ pip install -e . 22 | ``` 23 | 24 | ## Getting Started 💥 25 | 26 | ### [download pre-trained models](https://figshare.com/ndownloader/files/39472138) 27 | 28 | To train the reinforcement learning framework, you'll need to use pre-trained predictors for DAC and pre-trained generator. 
You can download them by running the following command in the `MOFreinforce/mofreinforce` directory: 29 | 30 | ```angular2html 31 | $ mofreinforce download default 32 | ``` 33 | Once downloaded, you can find the pre-trained generator and predictor models in the `mofreinforce/model` directory. 34 | 35 | ### [Predictor](https://github.com/hspark1212/MOFreinforce/blob/master/mofreinforce/predictor) 36 |

37 | 39 | 40 | In the model directory, you'll find the pre-trained predictors `model/predictor/preditor_qkh.ckpt` and `model/predictor/predictor_selectivity.ckpt` for CO2 heat of adsorption and CO2/H2O selectivity, respectively. If you want to train your own predictor for your desired property, you can refer to [predictor.md](https://github.com/hspark1212/MOFreinforce/blob/master/predictor.md). 41 | 42 | ### [Generator](https://github.com/hspark1212/MOFreinforce/blob/master/mofreinforce/generator) 43 |

44 | 46 | 47 | The pre-trained generator, which selects a topology and a metal cluster and creates an organic linker represented by a SELFIES string, was pre-trained on about 650,000 MOFs created by PORMAKE, which allows it to generate feasible MOFs. The pre-trained generator `model/generator/generator.ckpt` can be found in the model directory. 48 | 49 | ### [Reinforcement Learning](https://github.com/hspark1212/MOFreinforce/blob/master/mofreinforce/reinforce) 50 | To run reinforcement learning with CO2 heat of adsorption as the target, run in the `mofreinforce` directory: 51 | ```angular2html 52 | $ python run_reinforce.py with v0_qkh_round3 53 | ``` 54 | 55 | To run reinforcement learning with CO2/H2O selectivity as the target, run in the `mofreinforce` directory: 56 | ```angular2html 57 | $ python run_reinforce.py with v1_selectivity_round3 58 | ``` 59 | 60 | You can experiment with other parameters by modifying the [`mofreinforce/reinforce/config_reinforce.py`](https://github.com/hspark1212/MOFreinforce/blob/master/mofreinforce/reinforce/config_reinforce.py) file. You can also train the reinforcement learning framework with your own pre-trained predictor to generate high-performing MOFs with your own reward function. 61 | 62 | ### Testing and construction of MOFs with PORMAKE 63 | To test the reinforcement learning, run in the `mofreinforce` directory: 64 | ```angular2html 65 | $ python run_reinforce.py with v0_qkh_round3 log_dir=test test_only=True load_path=model/reinforce/best_v0_qkh_round3.ckpt 66 | ``` 67 | The optimized generators for CO2 heat of adsorption and CO2/H2O selectivity can be found in the `mofreinforce/model` directory. 68 | 69 | The generated MOFs obtained from the test set (10,000 samples) can be constructed with [PORMAKE](https://github.com/Sangwon91/PORMAKE). 70 | The details are summarized in the [`tutorial.ipynb`](https://github.com/hspark1212/MOFreinforce/blob/master/mofreinforce/tutorial.ipynb) file. 71 | 72 | ## Contributing 🙌 73 | 74 | Contributions are welcome! 
"""Top-level package of MOFreinforce: exposes the sub-packages and version."""
import os

__version__ = "0.0.1"
# Directory that contains this package (used to locate bundled resources).
__root_dir__ = os.path.dirname(__file__)

from mofreinforce import predictor, generator, reinforce

# ``__all__`` must hold attribute *names*.  The original list contained the
# value of ``__version__`` (the string "0.0.1"), so
# ``from mofreinforce import *`` failed with
# ``AttributeError: module 'mofreinforce' has no attribute '0.0.1'``.
__all__ = ["predictor", "generator", "reinforce", "__version__"]
def _describe(cmd):
    """Return ``(short, long)`` help texts for a ``CLICommand`` class."""
    docstring = cmd.__doc__
    if docstring is None:
        # Backwards compatibility with GPAW
        short = cmd.short_description
        return short, getattr(cmd, "description", short)

    # Strip first: a docstring written as '"""\n    text\n"""' starts with a
    # newline, so splitting the raw string on "\n" yields an empty first
    # line and the sub-command would be listed without any help text.
    lines = docstring.strip().splitlines()
    short = lines[0].strip() if lines else ""
    return short, short


def main(prog="mofreinforce", version=__version__, commands=commands, args=None):
    """Entry point of the ``mofreinforce`` command line interface.

    :param prog: program name shown in usage/help messages.
    :param version: version string printed by ``--version``.
    :param commands: iterable of ``(sub_command, module_name)`` pairs; each
        module must expose a ``CLICommand`` class with ``add_arguments`` and
        ``run``.
    :param args: argument list to parse; ``None`` lets argparse read
        ``sys.argv``.

    Errors raised by a sub-command are reported through ``parser.error``
    (which exits) unless ``-T/--traceback`` is given, in which case the
    original exception propagates.
    """
    parser = argparse.ArgumentParser(
        prog=prog,
    )
    parser.add_argument(
        "--version", action="version", version="%(prog)s-{}".format(version)
    )
    parser.add_argument("-T", "--traceback", action="store_true")
    subparsers = parser.add_subparsers(title="Sub-command", dest="command")

    subparser = subparsers.add_parser(
        "help", description="Help", help="Help for sub-command."
    )
    subparser.add_argument(
        "helpcommand",
        nargs="?",
        metavar="sub-command",
        help="Provide help for sub-command.",
    )

    functions = {}
    parsers = {}
    for command, module_name in commands:
        cmd = import_module(module_name).CLICommand
        short, long = _describe(cmd)
        subparser = subparsers.add_parser(
            command, formatter_class=RawTextHelpFormatter, help=short, description=long
        )
        cmd.add_arguments(subparser)
        functions[command] = cmd.run
        parsers[command] = subparser

    args = parser.parse_args(args)

    if args.command == "help":
        if args.helpcommand is None:
            parser.print_help()
        elif args.helpcommand in parsers:
            parsers[args.helpcommand].print_help()
        else:
            # Previously an unknown name (including "help" itself) surfaced
            # as a bare KeyError traceback.
            parser.error(
                "unknown sub-command {!r} (choose from: {})".format(
                    args.helpcommand, ", ".join(parsers)
                )
            )
    elif args.command is None:
        parser.print_usage()
    else:
        f = functions[args.command]
        try:
            # ``run`` may accept either (args) or (args, parser).
            if f.__code__.co_argcount == 1:
                f(args)
            else:
                f(args, parsers[args.command])
        except KeyboardInterrupt:
            pass
        except Exception as x:
            if args.traceback:
                raise
            else:
                l1 = "{}: {}\n".format(x.__class__.__name__, x)
                l2 = "To get a full traceback, use: {} -T {} ...".format(
                    prog, args.command
                )
                parser.error(l1 + l2)
https://raw.githubusercontent.com/hspark1212/MOFreinforce/b920084a29db9223482ddfeffd2e775b82f49e63/mofreinforce/generator/__init__.py -------------------------------------------------------------------------------- /mofreinforce/generator/config_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | from sacred import Experiment 3 | 4 | ex = Experiment("generator") 5 | 6 | mc_to_idx = json.load(open("data/mc_to_idx.json")) 7 | topo_to_idx = json.load(open("data/topo_to_idx.json")) 8 | 9 | 10 | @ex.config 11 | def config(): 12 | seed = 0 13 | exp_name = "generator" 14 | log_dir = "generator/logs" 15 | loss_names = {"generator": 1} 16 | 17 | # datamodule 18 | dataset_dir = "data/dataset_generator" 19 | batch_size = 256 20 | num_workers = 8 # recommend num_gpus * 4 21 | max_len = 128 22 | 23 | # transformer 24 | path_topo_to_idx = "data/topo_to_idx.json" 25 | path_mc_to_idx = "data/mc_to_idx.json" 26 | path_vocab = "data/vocab_to_idx.json" 27 | # input_dim = len(vocab_to_idx) 28 | # output_dim = len(vocab_to_idx) 29 | hid_dim = 256 30 | n_layers = 3 31 | n_heads = 8 32 | pf_dim = 512 33 | dropout = 0.1 34 | max_len = 128 35 | src_pad_idx = 0 36 | trg_pad_idx = 0 37 | 38 | # Trainer 39 | per_gpu_batchsize = 128 40 | num_nodes = 1 41 | devices = 2 42 | precision = 16 43 | resume_from = None 44 | val_check_interval = 1.0 45 | test_only = False 46 | load_path = "" 47 | gradient_clip_val = None 48 | 49 | # Optimizer Setting 50 | optim_type = "adam" # adamw, adam, sgd (momentum=0.9) 51 | learning_rate = 5e-4 52 | weight_decay = 0 53 | decay_power = ( 54 | "constant" # default polynomial decay, [cosine, constant, constant_with_warmup] 55 | ) 56 | max_epochs = 100 57 | max_steps = -1 # num_data * max_epoch // batch_size (accumulate_grad_batches) 58 | warmup_steps = 0.0 # int or float ( max_steps * warmup_steps) 59 | end_lr = 0 60 | lr_mult = 1 # multiply lr for downstream heads 61 | 62 | 63 | @ex.named_config 64 | 
def v0(): 65 | exp_name = "v0" 66 | 67 | 68 | @ex.named_config 69 | def test(): 70 | exp_name = "test" 71 | 72 | 73 | @ex.named_config 74 | def v0_test(): 75 | exp_name = "v0_test" 76 | load_path = "model/generator.ckpt" 77 | 78 | test_only = True 79 | num_devices = 1 80 | 81 | 82 | """ 83 | old experiments 84 | """ 85 | 86 | 87 | @ex.named_config 88 | def v0_grad_clip(): 89 | exp_name = "v0_grad_clip" 90 | 91 | gradient_clip_val = 0.5 92 | -------------------------------------------------------------------------------- /mofreinforce/generator/datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader 5 | from generator.dataset import GeneratorDataset 6 | 7 | 8 | class GeneratorDatamodule(LightningDataModule): 9 | def __init__(self, _config): 10 | super(GeneratorDatamodule, self).__init__() 11 | self.dataset_dir = _config["dataset_dir"] 12 | self.batch_size = _config["batch_size"] 13 | self.num_workers = _config["num_workers"] 14 | self.max_len = _config["max_len"] 15 | self.path_vocab = _config["path_vocab"] 16 | self.path_topo_to_idx = _config["path_topo_to_idx"] 17 | self.path_mc_to_idx = _config["path_mc_to_idx"] 18 | 19 | @property 20 | def dataset_cls(self): 21 | return GeneratorDataset 22 | 23 | def set_train_dataset(self): 24 | self.train_dataset = self.dataset_cls( 25 | dataset_dir=self.dataset_dir, 26 | path_vocab=self.path_vocab, 27 | path_topo_to_idx=self.path_topo_to_idx, 28 | path_mc_to_idx=self.path_mc_to_idx, 29 | split="train", 30 | max_len=self.max_len, 31 | ) 32 | 33 | def set_val_dataset(self): 34 | self.val_dataset = self.dataset_cls( 35 | dataset_dir=self.dataset_dir, 36 | path_vocab=self.path_vocab, 37 | path_topo_to_idx=self.path_topo_to_idx, 38 | path_mc_to_idx=self.path_mc_to_idx, 39 | split="val", 40 | max_len=self.max_len, 41 | ) 42 | 43 | def set_test_dataset(self): 44 
class GeneratorDataset(Dataset):
    """Dataset that reads ``{split}.csv`` and pre-encodes every row.

    The CSV must provide the columns ``topo``, ``mc``, ``num_conn``,
    ``frags`` and ``selfies``.  For each row two integer sequences are built
    once, at construction time:

    * ``encoded_input``:  ``[mc_idx, num_conn, SOS, *frag_tokens, EOS, PAD...]``
    * ``encoded_output``: ``[SOS, topo_idx, mc_idx, *selfies_tokens, EOS, PAD...]``

    NOTE: padding is ``max_len - 4 - n_tokens`` entries.  When a row has more
    than ``max_len - 4`` tokens that count is negative, list multiplication
    yields no padding, and the row ends up longer than ``max_len``;
    ``collate`` then cannot stack it with shorter rows.
    """

    def __init__(
        self, dataset_dir, path_vocab, path_topo_to_idx, path_mc_to_idx, split, max_len
    ):
        """
        :param dataset_dir: directory that contains ``{split}.csv``.
        :param path_vocab: JSON file mapping SELFIES tokens (and the special
            ``[SOS]``/``[EOS]``/``[PAD]`` tokens) to indices.
        :param path_topo_to_idx: JSON file mapping topology names to indices.
        :param path_mc_to_idx: JSON file mapping metal-cluster names to indices.
        :param split: one of ``"train"``, ``"test"``, ``"val"``.
        :param max_len: length every encoded sequence is padded to.
        """
        assert split in ["train", "test", "val"]
        # read vocab_to_idx
        self.vocab_to_idx = self._load_json(path_vocab)
        self.topo_to_idx = self._load_json(path_topo_to_idx)
        self.mc_to_idx = self._load_json(path_mc_to_idx)

        # load dataset
        path_data = Path(dataset_dir, f"{split}.csv")
        csv_ = pd.read_csv(path_data)
        print(f"read file : {path_data}, num_data : {len(csv_)}")
        self.topo = np.array(csv_["topo"])
        self.mc = np.array(csv_["mc"])
        self.num_conn = np.array(csv_["num_conn"])
        self.frags = np.array(csv_["frags"])
        self.selfies = np.array(csv_["selfies"])

        self.encoded_input, self.encoded_output = self.encoding(max_len)

    @staticmethod
    def _load_json(path):
        """Read a JSON file and close the handle (``json.load(open(...))`` leaked it)."""
        with open(path) as f:
            return json.load(f)

    def _tail(self, n_tokens, max_len):
        """Return the ``[EOS] + [PAD]...`` suffix for a row with ``n_tokens`` tokens."""
        # 4 = three leading entries (see class docstring) + the EOS token.
        n_pad = max_len - 4 - n_tokens
        return [self.vocab_to_idx["[EOS]"]] + [self.vocab_to_idx["[PAD]"]] * n_pad

    def encoding(self, max_len):
        """Encode all rows; returns ``(encoded_input, encoded_output)`` as lists of lists."""
        # making encoded_input
        encoded_input = []
        for i, f in enumerate(self.frags):
            encoded_frags = [self.vocab_to_idx[v] for v in sf.split_selfies(f)]
            head = [
                self.mc_to_idx[self.mc[i]],
                self.num_conn[i],
                self.vocab_to_idx["[SOS]"],
            ]
            encoded_input.append(
                head + encoded_frags + self._tail(len(encoded_frags), max_len)
            )

        # making encoded_output
        encoded_output = []
        for i, f in enumerate(self.selfies):
            encoded_selfies = [self.vocab_to_idx[v] for v in sf.split_selfies(f)]
            head = [
                self.vocab_to_idx["[SOS]"],
                self.topo_to_idx[self.topo[i]],
                self.mc_to_idx[self.mc[i]],
            ]
            encoded_output.append(
                head + encoded_selfies + self._tail(len(encoded_selfies), max_len)
            )

        return encoded_input, encoded_output

    def __len__(self):
        return len(self.selfies)

    def __getitem__(self, idx):
        """Return the raw columns and both encoded sequences of row ``idx``."""
        return {
            "topo": self.topo[idx],
            "mc": self.mc[idx],
            "frags": self.frags[idx],
            "selfies": self.selfies[idx],
            "encoded_input": self.encoded_input[idx],
            "encoded_output": self.encoded_output[idx],
        }

    @staticmethod
    def collate(batch):
        """Merge a list of item dicts into one dict of lists.

        Keys missing from an item are filled with ``None``.  ``encoded_input``
        and ``encoded_output`` are converted to ``torch.LongTensor``, which
        requires all rows to have the same length.
        """
        keys = {key for b in batch for key in b}
        dict_batch = {k: [dic.get(k) for dic in batch] for k in keys}
        dict_batch["encoded_input"] = torch.LongTensor(dict_batch["encoded_input"])
        dict_batch["encoded_output"] = torch.LongTensor(dict_batch["encoded_output"])
        return dict_batch
class Generator(LightningModule):
    """Lightning module wrapping the generator ``Transformer``.

    The decoder emits, in order, a topology index, a metal-cluster index and
    then organic-linker (SELFIES vocabulary) tokens; see ``evaluate`` for the
    greedy decoding loop and ``objectives.compute_loss`` for training.
    """

    def __init__(self, config):
        """
        :param config: mapping with the keys read below (``max_len``,
            ``path_topo_to_idx``, ``path_mc_to_idx``, ``path_vocab``, the
            transformer hyper-parameters and ``load_path``).
        """
        super(Generator, self).__init__()
        self.save_hyperparameters()

        self.max_len = config["max_len"]
        # topo
        self.topo_to_idx = self._load_json(config["path_topo_to_idx"])
        self.idx_to_topo = {v: k for k, v in self.topo_to_idx.items()}
        # mc
        self.mc_to_idx = self._load_json(config["path_mc_to_idx"])
        self.idx_to_mc = {v: k for k, v in self.mc_to_idx.items()}
        # ol
        self.vocab_to_idx = self._load_json(config["path_vocab"])
        self.idx_to_vocab = {v: k for k, v in self.vocab_to_idx.items()}

        self.transformer = Transformer(
            input_dim=len(self.vocab_to_idx),
            output_dim=len(self.vocab_to_idx),
            topo_dim=len(self.topo_to_idx),
            mc_dim=len(self.mc_to_idx),
            hid_dim=config["hid_dim"],
            n_layers=config["n_layers"],
            n_heads=config["n_heads"],
            pf_dim=config["pf_dim"],
            dropout=config["dropout"],
            max_len=config["max_len"],
            src_pad_idx=config["src_pad_idx"],
            trg_pad_idx=config["trg_pad_idx"],
        )

        module_utils.set_metrics(self)
        # ===================== load model ======================

        if config["load_path"] != "":
            ckpt = torch.load(config["load_path"], map_location="cpu")
            state_dict = ckpt["state_dict"]
            # strict=False: missing/unexpected keys are tolerated silently.
            self.load_state_dict(state_dict, strict=False)
            print(f"load model : {config['load_path']}")

    @staticmethod
    def _load_json(path):
        """Read a JSON file and close the handle (``json.load(open(...))`` leaked it)."""
        with open(path) as f:
            return json.load(f)

    def infer(self, batch):
        """Run the transformer with teacher forcing.

        The target is shifted by one position: ``encoded_output[:, :-1]`` is
        fed to the decoder and ``encoded_output[:, 1:]`` is the label.

        :return: the transformer's output dict extended with ``src``, ``tgt``
            and ``tgt_label``.
        """
        src = batch["encoded_input"]  # [B, max_len]

        tgt_input = batch["encoded_output"][:, :-1]  # [B, seq_len-1]
        tgt_label = batch["encoded_output"][:, 1:]  # [B, seq_len-1]

        out = self.transformer(src, tgt_input)

        out.update(
            {
                "src": src,
                "tgt": batch["encoded_output"],
                "tgt_label": tgt_label,
            }
        )

        return out

    def evaluate(self, src, max_len=128):
        """Greedily decode one sample.

        :param src: torch.LongTensor [B=1, seq_len]
        :param max_len: maximum number of decoding steps.
        :return: dict with ``topo``, ``mc``, ``topo_idx``, ``mc_idx``,
            ``ol_idx``, ``gen_sf`` (SELFIES string) and ``gen_sm`` (canonical
            SMILES).  ``gen_sm`` is ``None`` whenever decoding to a valid
            molecule failed; callers use that to count failures.
        """
        vocab_to_idx = self.vocab_to_idx
        src_mask = self.transformer.make_src_mask(src)

        # get encoded src
        enc_src = self.transformer.encoder(src, src_mask)

        # get target
        trg_indexes = [vocab_to_idx["[SOS]"]]
        for i in range(max_len):
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(src.device)
            trg_mask = self.transformer.make_trg_mask(trg_tensor)

            output = self.transformer.decoder(trg_tensor, enc_src, trg_mask, src_mask)

            # The decoder's output dict tells which element it just produced:
            # linker tokens take the last position, the other heads give a
            # single prediction.
            if "output_ol" in output.keys():
                out = output["output_ol"]
                pred_token = out.argmax(-1)[:, -1].item()
            elif "output_mc" in output.keys():
                out = output["output_mc"]
                pred_token = out.argmax(-1).item()
            else:
                out = output["output_topo"]
                pred_token = out.argmax(-1).item()
            trg_indexes.append(pred_token)

            # Topology / metal-cluster indices live in their own index
            # spaces, so EOS is only honoured after the first steps.
            if i > 3 and pred_token == vocab_to_idx["[EOS]"]:
                break

        # get topo
        topo_idx = trg_indexes[1]
        topo = self.idx_to_topo[topo_idx]
        # get mc
        mc_idx = trg_indexes[2]
        mc = self.idx_to_mc[mc_idx]
        # get ol
        ol_idx = trg_indexes[3:]
        ol_tokens = [self.idx_to_vocab[idx] for idx in ol_idx]
        # convert to selfies and smiles
        gen_sf = None
        gen_sm = None
        try:
            # NOTE: the last token is dropped unconditionally; it is EOS
            # unless the loop stopped because max_len was reached.
            gen_sf = "".join(ol_tokens[:-1])  # remove EOS token
            decoded = sf.decoder(gen_sf)
            m = Chem.MolFromSmiles(decoded)
            if m is None:
                # RDKit signals a parse failure by returning None.  The
                # original assigned the raw decoded string to gen_sm before
                # this point, so a failed molecule was reported as a success.
                raise ValueError(f"RDKit could not parse generated SMILES: {decoded}")
            gen_sm = Chem.MolToSmiles(m)  # canonical smiles
        except Exception as e:
            # Best effort: a failed sample is reported through gen_sm=None.
            print(e)

        ret = {
            "topo": topo,
            "mc": mc,
            "topo_idx": topo_idx,
            "mc_idx": mc_idx,
            "ol_idx": ol_idx,
            "gen_sf": gen_sf,
            "gen_sm": gen_sm,
        }
        return ret

    def forward(self, batch):
        """Compute (and log) the generator loss for ``batch``."""
        ret = dict()
        ret.update(objectives.compute_loss(self, batch))

        return ret

    def training_step(self, batch, batch_idx):
        ret = self(batch)
        # Sum every entry whose key contains "loss".
        total_loss = sum([v for k, v in ret.items() if "loss" in k])

        return total_loss

    def training_epoch_end(self, outputs):
        module_utils.epoch_wrapup(self)

    def validation_step(self, batch, batch_idx):
        # Called for its side effects (metric updates / logging) only.
        self(batch)

    def validation_epoch_end(self, output):
        module_utils.epoch_wrapup(self)

    def test_step(self, batch, batch_idx):
        # Batches are collected and decoded in test_epoch_end.
        return batch

    def test_epoch_end(self, batches):
        """Decode every test input, log generation metrics and sample images."""
        split = "test"
        module_utils.epoch_wrapup(self)

        metrics = Metrics(self.vocab_to_idx, self.idx_to_vocab)
        list_src = torch.concat([b["encoded_input"] for b in batches], dim=0)

        for src in tqdm(list_src):
            out = self.evaluate(src.unsqueeze(0))

            if out["gen_sm"] is None:
                metrics.num_fail.append(1)
                continue
            else:
                metrics.num_fail.append(0)

            metrics.update(out, src)

        self.log(f"{split}/conn_match", metrics.get_mean(metrics.conn_match))
        self.log(f"{split}/unique_ol", len(set(metrics.gen_ol)))
        self.log(
            f"{split}/unique_topo_mc", len(set(zip(metrics.gen_topo, metrics.gen_mc)))
        )
        self.log(f"{split}/scaffold", metrics.get_mean(metrics.scaffold))
        self.log(f"{split}/num_fail", metrics.get_mean(metrics.num_fail))

        # Nothing to draw when every sample failed; random.choice on an empty
        # range and np.stack on an empty list would both raise.
        if len(metrics.gen_ol) == 0:
            return

        # add image to log
        # gen_ol with frags (32 images)
        for i in range(32):
            # Seeded per image so the same samples are drawn on every run.
            idx = random.Random(i).choice(range(len(metrics.gen_ol)))
            ol = metrics.gen_ol[idx]
            frags = metrics.input_frags[idx]
            imgs = []
            for s in [ol] + frags:
                m = Chem.MolFromSmiles(s)
                if not m:
                    continue
                img = MolToImage(m)
                img = np.array(img)
                img = torch.tensor(img)
                imgs.append(img)
            if not imgs:
                continue
            imgs = np.stack(imgs, axis=0)
            self.logger.experiment.add_image(
                f"{split}/{i}", imgs, self.global_step, dataformats="NHWC"
            )

        # total gen_ol
        imgs = []
        for i, m in enumerate(metrics.gen_ol[:32]):
            try:
                m = Chem.MolFromSmiles(m)
                img = MolToImage(m)
                img = np.array(img)
                img = torch.tensor(img)
                imgs.append(img)
            except Exception as e:
                print(e)
        if imgs:
            imgs = np.stack(imgs, axis=0)
            self.logger.experiment.add_image(
                f"{split}/gen_ol/", imgs, self.global_step, dataformats="NHWC"
            )

    def configure_optimizers(self):
        return module_utils.set_schedule(self)
infer["output_ol"].reshape(-1, vocab_dim) 22 | label_ol = tgt_label[:, 2:].reshape(-1) 23 | loss_ol = F.cross_entropy(logit_ol, label_ol, ignore_index=0) 24 | # total loss 25 | total_loss = loss_topo + loss_mc + loss_ol 26 | 27 | ret = { 28 | "gen_loss": total_loss, 29 | "gen_labels": tgt_label, 30 | } 31 | 32 | # call update() loss and acc 33 | loss_name = "generator" 34 | phase = "train" if pl_module.training else "val" 35 | total_loss = getattr(pl_module, f"{phase}_{loss_name}_loss")(ret["gen_loss"]) 36 | 37 | # acc 38 | acc_topo = getattr(pl_module, f"{phase}_{loss_name}_acc_topo")( 39 | accuracy(logit_topo.argmax(-1), label_topo) 40 | ) 41 | acc_mc = getattr(pl_module, f"{phase}_{loss_name}_acc_mc")( 42 | accuracy(logit_mc.argmax(-1), label_mc) 43 | ) 44 | acc_ol = getattr(pl_module, f"{phase}_{loss_name}_acc_ol")( 45 | accuracy(logit_ol.argmax(-1), label_ol, ignore_index=0) 46 | ) 47 | 48 | pl_module.log( 49 | f"{loss_name}/{phase}/total_loss", 50 | total_loss, 51 | batch_size=batch_size, 52 | prog_bar=True, 53 | sync_dist=True, 54 | ) 55 | pl_module.log( 56 | f"{loss_name}/{phase}/loss_topo", 57 | loss_topo, 58 | batch_size=batch_size, 59 | prog_bar=False, 60 | sync_dist=True, 61 | ) 62 | pl_module.log( 63 | f"{loss_name}/{phase}/loss_mc", 64 | loss_mc, 65 | batch_size=batch_size, 66 | prog_bar=False, 67 | sync_dist=True, 68 | ) 69 | pl_module.log( 70 | f"{loss_name}/{phase}/loss_ol", 71 | loss_ol, 72 | batch_size=batch_size, 73 | prog_bar=False, 74 | sync_dist=True, 75 | ) 76 | pl_module.log( 77 | f"{loss_name}/{phase}/acc_topo", 78 | acc_topo, 79 | batch_size=batch_size, 80 | prog_bar=True, 81 | sync_dist=True, 82 | ) 83 | pl_module.log( 84 | f"{loss_name}/{phase}/acc_mc", 85 | acc_mc, 86 | batch_size=batch_size, 87 | prog_bar=True, 88 | sync_dist=True, 89 | ) 90 | pl_module.log( 91 | f"{loss_name}/{phase}/acc_ol", 92 | acc_ol, 93 | batch_size=batch_size, 94 | prog_bar=True, 95 | sync_dist=True, 96 | ) 97 | return ret 98 | 
def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    :param query: [B, num_heads, seq_len, hid_dim//num_heads]
    :param key: [B, num_heads, seq_len, hid_dim//num_heads]
    :param value: [B, num_heads, seq_len, hid_dim//num_heads]
    :param mask: optional tensor broadcastable to the score matrix
        (e.g. [B, 1, 1, seq_len]); positions where it equals 0 are excluded
        from attention.
    :param dropout: optional callable (such as an ``nn.Dropout`` module)
        applied to the attention weights.
    :return: tuple ``(output, p_attn)`` -- the attended values
        [B, num_heads, seq_len, hid_dim//num_heads] and the attention
        weights [B, num_heads, seq_len, seq_len].
    """
    d_k = query.size(-1)  # per-head feature size
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # NOTE: a query row whose keys are all masked becomes all -inf and
        # therefore NaN after the softmax; callers must keep at least one
        # visible key per row.
        scores = scores.masked_fill(mask == 0, -float('inf'))

    p_attn = scores.softmax(dim=-1)

    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


class MultiHeadedAttention(nn.Module):
    """Multi-head attention built on :func:`attention`.

    The model dimension is split evenly across ``h`` heads; queries, keys and
    values each get their own linear projection and the concatenated head
    outputs go through a final linear layer.
    """

    def __init__(self, h, d_model, dropout_rate=0.1):
        """
        :param h: number of attention heads; must divide ``d_model``.
        :param d_model: model (hidden) dimension.
        :param dropout_rate: dropout probability on the attention weights.
        """
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.layer_q = nn.Linear(d_model, d_model)
        self.layer_k = nn.Linear(d_model, d_model)
        self.layer_v = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(p=dropout_rate)

    def _split_heads(self, x, batch_size):
        # [B, seq_len, hid_dim] -> [B, num_heads, seq_len, hid_dim//num_heads]
        return x.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """
        :param q: [B, seq_len, hid_dim]
        :param k: [B, seq_len, hid_dim]
        :param v: [B, seq_len, hid_dim]
        :param mask: [B, 1, seq_len] for src or [B, seq_len, seq_len] for
            tgt; a head axis is inserted so it broadcasts over all heads.
        :return: tuple ``(x, p_attn)`` -- [B, seq_len, hid_dim] and
            [B, num_heads, seq_len, seq_len].
        """
        if mask is not None:
            mask = mask.unsqueeze(1)  # add the head axis
        batch_size = q.size(0)

        # 1) Project and split into heads. The value tensor must go through
        #    its own projection (it previously reused the key projection,
        #    leaving ``layer_v`` unused).
        q = self._split_heads(self.layer_q(q), batch_size)
        k = self._split_heads(self.layer_k(k), batch_size)
        v = self._split_heads(self.layer_v(v), batch_size)

        # 2) Apply attention on all the projected vectors in batch.
        x, p_attn = attention(q, k, v, mask=mask, dropout=self.dropout)

        # 3) "Concat" the heads using a view and apply the final linear.
        x = (
            x.transpose(1, 2)  # [B, seq_len, num_heads, hid_dim//num_heads]
            .contiguous()
            .view(batch_size, -1, self.h * self.d_k)  # [B, seq_len, hid_dim]
        )

        x = self.proj(x)
        return x, p_attn
class PositionwiseFeedforwardLayer(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, hid_dim, pf_dim, dropout):
        """
        :param hid_dim: feature size of the input and of the output
        :param pf_dim: feature size of the intermediate (expanded) layer
        :param dropout: dropout probability applied after the ReLU
        """
        super().__init__()

        # Attribute names and creation order are part of the contract:
        # they determine the state_dict keys and seeded initialisation.
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        :param x: [B, seq_len, hid_dim]
        :return: [B, seq_len, hid_dim]
        """
        expanded = torch.relu(self.fc_1(x))  # [B, seq_len, pf_dim]
        expanded = self.dropout(expanded)
        return self.fc_2(expanded)  # [B, seq_len, hid_dim]
class Transformer(nn.Module):
    """
    Transformer for MOF Generator.

    Wraps an :class:`Encoder` and a :class:`Decoder`, builds the padding and
    look-ahead masks from the raw token ids, and applies Xavier-uniform
    initialisation to every sub-module weight with more than one dimension.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 topo_dim,
                 mc_dim,
                 hid_dim,
                 n_layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 max_len,
                 src_pad_idx,
                 trg_pad_idx,
                 ):
        """
        :param input_dim: forwarded to the encoder as ``input_dim``
        :param output_dim: forwarded to the decoder as ``output_dim``
        :param topo_dim: forwarded to the decoder as ``topo_dim``
        :param mc_dim: forwarded to both encoder and decoder
        :param hid_dim: hidden size shared by encoder and decoder
        :param n_layers: number of layers in each of encoder and decoder
        :param n_heads: number of attention heads
        :param pf_dim: feed-forward inner size
        :param dropout: dropout probability
        :param max_len: forwarded to encoder and decoder as ``max_len``
        :param src_pad_idx: token id treated as padding in ``src``
        :param trg_pad_idx: token id treated as padding in ``trg``
        """
        super().__init__()

        self.encoder = Encoder(
            input_dim=input_dim,
            mc_dim=mc_dim,
            hid_dim=hid_dim,
            n_layers=n_layers,
            n_heads=n_heads,
            pf_dim=pf_dim,
            dropout=dropout,
            max_len=max_len,
        )
        self.decoder = Decoder(
            output_dim=output_dim,
            topo_dim=topo_dim,
            mc_dim=mc_dim,
            hid_dim=hid_dim,
            n_layers=n_layers,
            n_heads=n_heads,
            pf_dim=pf_dim,
            dropout=dropout,
            max_len=max_len,
        )
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

        self.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-initialise ``m.weight`` when it has more than one dimension.

        One-dimensional weights are left untouched.
        """
        if hasattr(m, 'weight') and m.weight.dim() > 1:
            nn.init.xavier_uniform_(m.weight.data)

    def make_src_mask(self, src):
        """
        make padding mask for src
        :param src: [B, src_len]
        :return: [B, 1, 1, src_len] boolean mask, True at non-padding tokens
        """
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask

    def make_trg_mask(self, trg):
        """
        make padding and look-ahead mask for trg
        :param trg: [B, trg_len]
        :return: [B, 1, trg_len, trg_len] boolean mask; entry (i, j) is True
            when key position j is not padding and j <= i
        """
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)  # [batch size, 1, 1, trg len]

        trg_len = trg.shape[1]
        # Build the lower-triangular mask directly on the target's device so
        # no CPU allocation + host-to-device copy happens on every forward.
        trg_sub_mask = torch.ones(
            (trg_len, trg_len), device=trg.device
        ).tril().bool()  # [trg len, trg len]

        trg_mask = trg_pad_mask & trg_sub_mask  # [batch size, 1, trg len, trg len]

        return trg_mask

    def forward(self, src, trg):
        """
        :param src: [B, src_len]
        :param trg: [B, trg_len]
        :return: whatever ``self.decoder`` returns for the encoded source
        """
        src_mask = self.make_src_mask(src)  # [batch size, 1, 1, src len]
        trg_mask = self.make_trg_mask(trg)  # [batch size, 1, trg len, trg len]

        # encoder
        enc_src = self.encoder(src, src_mask)  # [batch size, src len, hid dim]

        # decoder
        output = self.decoder(trg, enc_src, trg_mask, src_mask)
        return output
class Decoder(nn.Module):
    """Transformer decoder with separate topology / metal-cluster / linker heads.

    The target sequence is laid out as ``[SOS, topo, mc, ol...]``. Each slot
    is embedded with its own table (the SOS token shares the linker
    vocabulary table), and the hidden states at positions 0, 1 and 2+ are
    projected to topology, metal-cluster and linker logits respectively.
    """

    def __init__(self,
                 output_dim,
                 topo_dim,
                 mc_dim,
                 hid_dim,
                 n_layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 max_len,
                 ):
        """
        :param output_dim: size of the linker vocabulary (embedding + head)
        :param topo_dim: number of topology classes (embedding + head)
        :param mc_dim: number of metal-cluster classes (embedding + head)
        :param hid_dim: hidden size
        :param n_layers: number of ``DecoderLayer`` blocks
        :param n_heads: attention heads per layer
        :param pf_dim: feed-forward inner size per layer
        :param dropout: dropout probability
        :param max_len: number of rows in the positional embedding table
        """
        super().__init__()

        # Sub-module names and creation order are kept as-is: they fix the
        # state_dict keys and the seeded parameter initialisation.
        self.topo_embedding = nn.Embedding(topo_dim, hid_dim)
        self.mc_embedding = nn.Embedding(mc_dim, hid_dim)
        self.vocab_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_len, hid_dim)

        self.layers = nn.ModuleList(
            [DecoderLayer(hid_dim, n_heads, pf_dim, dropout) for _ in range(n_layers)]
        )

        self.fc_out_topo = nn.Linear(hid_dim, topo_dim)
        self.fc_out_mc = nn.Linear(hid_dim, mc_dim)
        self.fc_out_ol = nn.Linear(hid_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

        # Plain tensor attribute (not a registered buffer); moved to the
        # input's device inside forward().
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim]))

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """
        Embedding table used per target position:
            position 0 (SOS)  -> vocab_embedding
            position 1 (topo) -> topo_embedding
            position 2 (mc)   -> mc_embedding
            positions 3+ (ol) -> vocab_embedding

        :param trg: [B, trg_len]
        :param enc_src: [B, src_len, hid_dim]
        :param trg_mask: [B, 1, trg_len, trg_len]
        :param src_mask: [B, 1, 1, src_len]
        :return: dict with ``"output_topo"`` [B, topo_dim]; ``"output_mc"``
            [B, mc_dim] unless ``trg_len == 1``; ``"output_ol"``
            [B, trg_len - 2, output_dim] unless ``trg_len`` is 1 or 2; and
            ``"attention"`` as returned by the last decoder layer.
        """
        n_batch, trg_len = trg.shape[0], trg.shape[1]

        positions = (
            torch.arange(0, trg_len)
            .unsqueeze(0)
            .repeat(n_batch, 1)
            .to(trg.device)
        )  # [batch size, trg len]

        def column(i):
            # single target column kept as [B, 1] so it embeds to [B, 1, hid]
            return trg[:, i].unsqueeze(-1)

        # Assemble the token embeddings slot by slot, left to right.
        pieces = [self.vocab_embedding(column(0))]  # [SOS]
        if trg_len != 1:
            pieces.append(self.topo_embedding(column(1)))  # topo
            if trg_len != 2:
                pieces.append(self.mc_embedding(column(2)))  # mc
                if trg_len != 3:
                    pieces.append(self.vocab_embedding(trg[:, 3:]))  # ol
        tok_emb = torch.concat(pieces, dim=1)  # [batch size, trg len, hid dim]

        scaled = tok_emb * self.scale.to(trg.device)
        trg = self.dropout(scaled + self.pos_embedding(positions))

        for layer in self.layers:
            # [batch size, trg len, hid dim], [batch size, n heads, trg len, src len]
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)

        # Heads are emitted only for positions that exist in the input.
        output = {"output_topo": self.fc_out_topo(trg[:, 0])}
        if trg_len != 1:
            output["output_mc"] = self.fc_out_mc(trg[:, 1])
            if trg_len != 2:
                output["output_ol"] = self.fc_out_ol(trg[:, 2:])
        output["attention"] = attention
        return output
17 | 18 | (3) "*" has been added to the regex rule in line 110 of `selfies.grammar_rules.py`: 19 | ```python 20 | r"([A-Z][a-z]?|\*)" # element symbol 21 | ``` 22 | -------------------------------------------------------------------------------- /mofreinforce/libs/selfies/__init__.py: -------------------------------------------------------------------------------- 1 | # This code is adapted from the "selfies" repository: 2 | # https://github.com/aspuru-guzik-group/selfies.git 3 | 4 | # The code in this file is licensed under the Apache License, Version 2.0: 5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE 6 | 7 | # Adapted by Hyunsoo Park, 2022 8 | 9 | #!/usr/bin/env python 10 | 11 | """ 12 | SELFIES: a robust representation of semantically constrained graphs with an 13 | example application in chemistry. 14 | 15 | SELFIES (SELF-referencIng Embedded Strings) is a general-purpose, 16 | sequence-based, robust representation of semantically constrained graphs. 17 | It is based on a Chomsky type-2 grammar, augmented with two self-referencing 18 | functions. A main objective is to use SELFIES as direct input into machine 19 | learning models, in particular in generative models, for the generation of 20 | outputs with high validity. 21 | 22 | The code presented here is a concrete application of SELFIES in chemistry, for 23 | the robust representation of molecules. 24 | 25 | Typical usage example: 26 | import selfies as sf 27 | 28 | benzene = "C1=CC=CC=C1" 29 | benzene_selfies = sf.encoder(benzene) 30 | benzene_smiles = sf.decoder(benzene_selfies) 31 | 32 | For comments, bug reports or feature ideas, please send an email to 33 | mario.krenn@utoronto.ca and alan@aspuru.com. 
34 | """ 35 | 36 | __version__ = "2.1.0" 37 | 38 | __all__ = [ 39 | "encoder", 40 | "decoder", 41 | "get_preset_constraints", 42 | "get_semantic_robust_alphabet", 43 | "get_semantic_constraints", 44 | "set_semantic_constraints", 45 | "len_selfies", 46 | "split_selfies", 47 | "get_alphabet_from_selfies", 48 | "selfies_to_encoding", 49 | "batch_selfies_to_flat_hot", 50 | "encoding_to_selfies", 51 | "batch_flat_hot_to_selfies", 52 | "EncoderError", 53 | "DecoderError" 54 | ] 55 | 56 | from .bond_constraints import ( 57 | get_preset_constraints, 58 | get_semantic_constraints, 59 | get_semantic_robust_alphabet, 60 | set_semantic_constraints 61 | ) 62 | from .decoder import decoder 63 | from .encoder import encoder 64 | from .exceptions import DecoderError, EncoderError 65 | from .utils.encoding_utils import ( 66 | batch_flat_hot_to_selfies, 67 | batch_selfies_to_flat_hot, 68 | encoding_to_selfies, 69 | selfies_to_encoding 70 | ) 71 | from .utils.selfies_utils import ( 72 | get_alphabet_from_selfies, 73 | len_selfies, 74 | split_selfies 75 | ) 76 | -------------------------------------------------------------------------------- /mofreinforce/libs/selfies/bond_constraints.py: -------------------------------------------------------------------------------- 1 | # This code is adapted from the "selfies" repository: 2 | # https://github.com/aspuru-guzik-group/selfies.git 3 | 4 | # The code in this file is licensed under the Apache License, Version 2.0: 5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE 6 | 7 | # Adapted by Hyunsoo Park, 2022 8 | 9 | import functools 10 | from itertools import product 11 | from typing import Dict, Set, Union 12 | 13 | from .constants import ELEMENTS, INDEX_ALPHABET 14 | 15 | _DEFAULT_CONSTRAINTS = { 16 | "H": 1, "F": 1, "Cl": 1, "Br": 1, "I": 1, 17 | "B": 3, "B+1": 2, "B-1": 4, 18 | "O": 2, "O+1": 3, "O-1": 1, 19 | "N": 3, "N+1": 4, "N-1": 2, 20 | "C": 4, "C+1": 5, "C-1": 3, 21 | "P": 5, "P+1": 6, "P-1": 4, 22 | "S": 6, 
"S+1": 7, "S-1": 5, 23 | "?": 8 24 | } 25 | 26 | _PRESET_CONSTRAINTS = { 27 | "default": dict(_DEFAULT_CONSTRAINTS), 28 | "octet_rule": dict(_DEFAULT_CONSTRAINTS), 29 | "hypervalent": dict(_DEFAULT_CONSTRAINTS) 30 | } 31 | _PRESET_CONSTRAINTS["octet_rule"].update( 32 | {"S": 2, "S+1": 3, "S-1": 1, "P": 3, "P+1": 4, "P-1": 2} 33 | ) 34 | _PRESET_CONSTRAINTS["hypervalent"].update( 35 | {"Cl": 7, "Br": 7, "I": 7, "N": 5} 36 | ) 37 | 38 | _current_constraints = _PRESET_CONSTRAINTS["default"] 39 | 40 | 41 | def get_preset_constraints(name: str) -> Dict[str, int]: 42 | """Returns the preset semantic constraints with the given name. 43 | 44 | Besides the aforementioned default constraints, :mod:`selfies` offers 45 | other preset constraints for convenience; namely, constraints that 46 | enforce the `octet rule `_ 47 | and constraints that accommodate `hypervalent molecules 48 | `_. 49 | 50 | The differences between these constraints can be summarized as follows: 51 | 52 | .. table:: 53 | :align: center 54 | :widths: auto 55 | 56 | +-----------------+-----------+---+---+-----+-----+---+-----+-----+ 57 | | | Cl, Br, I | N | P | P+1 | P-1 | S | S+1 | S-1 | 58 | +-----------------+-----------+---+---+-----+-----+---+-----+-----+ 59 | | ``default`` | 1 | 3 | 5 | 6 | 4 | 6 | 7 | 5 | 60 | +-----------------+-----------+---+---+-----+-----+---+-----+-----+ 61 | | ``octet_rule`` | 1 | 3 | 3 | 4 | 2 | 2 | 3 | 1 | 62 | +-----------------+-----------+---+---+-----+-----+---+-----+-----+ 63 | | ``hypervalent`` | 7 | 5 | 5 | 6 | 4 | 6 | 7 | 5 | 64 | +-----------------+-----------+---+---+-----+-----+---+-----+-----+ 65 | 66 | :param name: the preset name: ``default`` or ``octet_rule`` or 67 | ``hypervalent``. 68 | :return: the preset constraints with the specified name, represented 69 | as a dictionary which maps atoms (the keys) to their bonding capacities 70 | (the values). 
def set_semantic_constraints(
        bond_constraints: Union[str, Dict[str, int]] = "default"
) -> None:
    """Replace the semantic constraints that :mod:`selfies` operates on.

    A string argument selects the preset of that name
    (see :func:`selfies.get_preset_constraints`).

    A dictionary argument is validated and then copied. It maps atoms to
    non-negative integer bonding capacities. Atom keys have the form ``E``,
    ``E+C`` or ``E-C`` where ``E`` is an element symbol and ``C`` is a
    positive integer, for example:

        * ``bond_constraints["I-1"] = 0``
        * ``bond_constraints["C"] = 4``

    The dictionary must also contain the special ``?`` key, which supplies
    the capacity of every atom not listed explicitly.

    :param bond_constraints: the name of a preset, or a dictionary
        representing the new semantic constraints.
    :return: ``None``.
    :raises ValueError: if the argument is neither ``str`` nor ``dict``, the
        preset name is unknown, ``?`` is missing, or a key/value is invalid.
    """

    global _current_constraints

    if isinstance(bond_constraints, str):
        updated = get_preset_constraints(bond_constraints)

    elif not isinstance(bond_constraints, dict):
        raise ValueError("bond_constraints must be a str or dict")

    else:
        if "?" not in bond_constraints:
            raise ValueError("bond_constraints missing '?' as a key")

        for atom, capacity in bond_constraints.items():

            # --- key check: "?" is always fine; otherwise element[+/-charge]
            if atom != "?":
                sign_at = max(atom.find("+"), atom.find("-"))
                if sign_at == -1:
                    key_ok = atom in ELEMENTS
                else:
                    key_ok = (atom[:sign_at] in ELEMENTS) \
                        and atom[sign_at + 1:].isnumeric()
                if not key_ok:
                    raise ValueError(
                        "invalid key '{}' in bond_constraints".format(atom)
                    )

            # --- value check: non-negative integer capacity
            if not isinstance(capacity, int) or capacity < 0:
                raise ValueError(
                    "invalid value at bond_constraints['{}'] = {}".format(
                        atom, capacity
                    )
                )

        updated = dict(bond_constraints)

    _current_constraints = updated

    # Cached results were computed under the old constraints; drop them.
    get_semantic_robust_alphabet.cache_clear()
    get_bonding_capacity.cache_clear()
@functools.lru_cache()
def get_bonding_capacity(element: str, charge: int) -> int:
    """Look up the bonding capacity of an atom under the current constraints.

    The lookup key is the element symbol, followed by the signed charge
    (e.g. ``"N+1"``, ``"O-1"``) when the charge is non-zero. Atoms with no
    entry of their own fall back to the ``?`` entry.

    :param element: the element of the input atom.
    :param charge: the charge of the input atom.
    :return: the bonding capacity of the input atom.
    """

    atom_key = element if charge == 0 else element + "{:+}".format(charge)

    try:
        return _current_constraints[atom_key]
    except KeyError:
        # not listed explicitly -> use the catch-all capacity
        return _current_constraints["?"]
def _tokenize_selfies(selfies, compatible):
    """Yield the symbols of a SELFIES input, skipping ``[nop]``.

    :param selfies: a SELFIES string (split with ``split_selfies``) or a
        list that is iterated as-is.
    :param compatible: when true, each symbol is passed through
        ``modernize_symbol`` before being yielded.
    :raises ValueError: on first iteration, if ``selfies`` is neither a
        ``str`` nor a ``list``.
    :raises DecoderError: if a ``ValueError`` is raised while the symbols
        are being produced or modernized.
    """
    if isinstance(selfies, str):
        symbols = split_selfies(selfies)
    elif isinstance(selfies, list):
        symbols = selfies
    else:
        raise ValueError()  # should not happen

    try:
        for symbol in symbols:
            if symbol != "[nop]":
                yield modernize_symbol(symbol) if compatible else symbol
    except ValueError as err:
        # surface tokenizer problems as a decoder failure, without the
        # original traceback context
        raise DecoderError(str(err)) from None
[Branch1]) 120 | if "ch" == symbol[-4:-2]: 121 | 122 | output = process_branch_symbol(symbol) 123 | if output is None: 124 | _raise_decoder_error(selfies, symbol) 125 | btype, n = output 126 | 127 | if state <= 1: 128 | next_state = state 129 | else: 130 | binit_state, next_state = next_branch_state(btype, state) 131 | 132 | Q = _read_index_from_selfies(symbol_iter, n_symbols=n) 133 | n_derived += n + _derive_mol_from_symbols( 134 | symbol_iter, mol, selfies, (Q + 1), 135 | init_state=binit_state, root_atom=prev_atom, rings=rings, 136 | attribute_stack=attribute_stack + 137 | [Attribution(index + attribution_index, symbol) 138 | ] if attribute_stack is not None else None, 139 | attribution_index=attribution_index 140 | ) 141 | 142 | # Case 2: Ring symbol (e.g. [Ring2]) 143 | elif "ng" == symbol[-4:-2]: 144 | 145 | output = process_ring_symbol(symbol) 146 | if output is None: 147 | _raise_decoder_error(selfies, symbol) 148 | ring_type, n, stereo = output 149 | 150 | if state == 0: 151 | next_state = state 152 | else: 153 | ring_order, next_state = next_ring_state(ring_type, state) 154 | bond_info = (ring_order, stereo) 155 | 156 | Q = _read_index_from_selfies(symbol_iter, n_symbols=n) 157 | n_derived += n 158 | lidx = max(0, prev_atom.index - (Q + 1)) 159 | rings.append((mol.get_atom(lidx), prev_atom, bond_info)) 160 | 161 | # Case 3: [epsilon] 162 | elif "eps" in symbol: 163 | next_state = 0 if (state == 0) else None 164 | 165 | # Case 4: regular symbol (e.g. 
def _read_index_from_selfies(symbol_iter, n_symbols):
    """Consume up to ``n_symbols`` items and decode them as an index.

    :param symbol_iter: iterator whose items carry the symbol in their
        last position (the decoder feeds it ``enumerate`` pairs).
    :param n_symbols: how many index symbols to read.
    :return: the value of ``get_index_from_selfies`` applied to the symbols
        read; positions past the end of the iterator are passed as ``None``.
    """
    exhausted = object()  # private marker: cannot clash with a real item
    index_symbols = []
    for _ in range(n_symbols):
        item = next(symbol_iter, exhausted)
        index_symbols.append(None if item is exhausted else item[-1])
    return get_index_from_selfies(*index_symbols)
def encoder(smiles: str, strict: bool = True, attribute: bool = False) -> str:
    """Translates a SMILES string into its corresponding SELFIES string.

    The translation is deterministic, independent of the current semantic
    constraints, and keeps the atom order of the input SMILES string.
    Randomized SELFIES strings can therefore be obtained by translating
    randomized SMILES strings.

    A molecule that violates the current semantic constraints cannot be
    represented as a SELFIES string. With ``strict=True`` such an input
    raises :class:`selfies.EncoderError`. With ``strict=False`` no error
    is raised, but :func:`selfies.decoder` is then *not* guaranteed to
    recover a SMILES string of the original molecule from the result.

    :param smiles: the SMILES string to be translated. It is recommended to
        validate it (e.g. with RDKit) before calling this function.
    :param strict: if ``True``, check that the input obeys the semantic
        constraints. Defaults to ``True``.
    :param attribute: if an attribution should be returned as well.
    :return: the SELFIES string if ``attribute`` is ``False``, otherwise a
        tuple of the SELFIES string and the attribution list.
    :raises EncoderError: if the input SMILES string cannot be parsed,
        cannot be kekulized, or violates the semantic constraints with
        ``strict=True``.

    :Example:

    >>> import selfies as sf
    >>> sf.encoder("C=CF")
    '[C][=C][F]'

    .. note:: This function does not currently support SMILES with:

        * The wildcard symbol ``*``.
        * The quadruple bond symbol ``$``.
        * Chirality specifications other than ``@`` and ``@@``.
        * Ring bonds across a dot symbol (e.g. ``c1cc([O-].[Na+])ccc1``) or
          ring bonds between atoms that are over 4000 atoms apart.

        Aromatic SMILES strings *are* supported: they are kekulized
        internally before translation, since SELFIES has no aromatic
        symbols.
    """

    try:
        graph = smiles_to_mol(smiles, attributable=attribute)
    except SMILESParserError as err:
        raise EncoderError(
            "failed to parse input\n\tSMILES: {}".format(smiles)
        ) from err

    kekulized = graph.kekulize()
    if not kekulized:
        raise EncoderError(
            "kekulization failed\n\tSMILES: {}".format(smiles)
        )

    if strict:
        _check_bond_constraints(graph, smiles)

    # flip chirality where needed so that decoding the SELFIES string
    # restores the original chirality
    for atom in graph.get_atoms():
        if atom.chirality is None:
            continue
        if not graph.has_out_ring_bond(atom.index):
            continue
        if _should_invert_chirality(graph, atom):
            atom.invert_chirality()

    pieces = []
    collected_maps = []
    offset = 0  # number of tokens emitted by the previous fragments
    for root in graph.get_roots():
        tokens = list(_fragment_to_selfies(
            graph, None, root, collected_maps, offset))
        offset += len(tokens)
        pieces.append("".join(tokens))

    selfies = ".".join(pieces)
    if not attribute:
        return selfies
    # drop attribution entries whose token is empty
    return selfies, [m for m in collected_maps if m.token]
err_msg += "\t[{:} with {} bond(s) - " \ 122 | "a max. of {} bond(s) was specified]\n".format(*e) 123 | raise EncoderError(err_msg) 124 | 125 | 126 | def _should_invert_chirality(mol, atom): 127 | out_bonds = mol.get_out_dirbonds(atom.index) 128 | 129 | # 1. rings whose right number are bonded to this atom (e.g. ...1...X1) 130 | # 2. rings whose left number are bonded to this atom (e.g. X1...1...) 131 | # 3. branches and other (e.g. X(...)...) 132 | partition = [[], [], []] 133 | for i, bond in enumerate(out_bonds): 134 | if not bond.ring_bond: 135 | partition[2].append(i) 136 | elif bond.src < bond.dst: 137 | partition[1].append(i) 138 | else: 139 | partition[0].append(i) 140 | partition[1].sort(key=lambda x: out_bonds[x].dst) 141 | 142 | # construct permutation 143 | perm = partition[0] + partition[1] + partition[2] 144 | count = 0 145 | for i in range(len(perm)): 146 | for j in range(i + 1, len(perm)): 147 | if perm[i] > perm[j]: 148 | count += 1 149 | return count % 2 != 0 # if odd permutation, should invert chirality 150 | 151 | 152 | def _fragment_to_selfies(mol, bond_into_root, root, 153 | attribution_maps, attribution_index=0): 154 | derived = [] 155 | 156 | bond_into_curr, curr = bond_into_root, root 157 | while True: 158 | curr_atom = mol.get_atom(curr) 159 | token = _atom_to_selfies(bond_into_curr, curr_atom) 160 | derived.append(token) 161 | 162 | attribution_maps.append(AttributionMap( 163 | len(derived) - 1 + attribution_index, 164 | token, mol.get_attribution(curr_atom))) 165 | 166 | out_bonds = mol.get_out_dirbonds(curr) 167 | for i, bond in enumerate(out_bonds): 168 | 169 | if bond.ring_bond: 170 | if bond.src < bond.dst: 171 | continue 172 | 173 | rev_bond = mol.get_dirbond(src=bond.dst, dst=bond.src) 174 | ring_len = bond.src - bond.dst 175 | Q_as_symbols = get_selfies_from_index(ring_len - 1) 176 | ring_symbol = "[{}Ring{}]".format( 177 | _ring_bonds_to_selfies(rev_bond, bond), 178 | len(Q_as_symbols) 179 | ) 180 | 181 | 
def _ring_bonds_to_selfies(lbond, rbond):
    """Build the bond prefix of a ring symbol from its two directed bonds.

    :param lbond: the directed bond on the left side of the ring.
    :param rbond: the directed bond on the right side of the ring; it must
        have the same order as ``lbond``.
    :return: for a single ring bond with stereo information on at least
        one side, two characters (left, then right) where a side without
        stereo information is written as ``-``; otherwise the result of
        ``_bond_to_selfies(lbond, show_stereo=False)``.
    """
    assert lbond.order == rbond.order

    sides = (lbond, rbond)
    if lbond.order == 1 and any(b.stereo is not None for b in sides):
        return "".join("-" if b.stereo is None else b.stereo for b in sides)
    return _bond_to_selfies(lbond, show_stereo=False)
42 | """ 43 | 44 | pass 45 | -------------------------------------------------------------------------------- /mofreinforce/libs/selfies/grammar_rules.py: -------------------------------------------------------------------------------- 1 | # This code is adapted from the "selfies" repository: 2 | # https://github.com/aspuru-guzik-group/selfies.git 3 | 4 | # The code in this file is licensed under the Apache License, Version 2.0: 5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE 6 | 7 | # Adapted by Hyunsoo Park, 2022 8 | 9 | import functools 10 | import itertools 11 | import re 12 | from typing import Any, List, Optional, Tuple 13 | 14 | from .constants import ( 15 | ELEMENTS, 16 | INDEX_ALPHABET, 17 | INDEX_CODE, 18 | ORGANIC_SUBSET 19 | ) 20 | from .mol_graph import Atom 21 | from .utils.smiles_utils import smiles_to_bond 22 | 23 | 24 | def process_atom_symbol(symbol: str) -> Optional[Tuple[Any, Atom]]: 25 | try: 26 | output = _PROCESS_ATOM_CACHE[symbol] 27 | except KeyError: 28 | output = _process_atom_selfies_no_cache(symbol) 29 | if output is None: 30 | return None 31 | _PROCESS_ATOM_CACHE[symbol] = output 32 | 33 | bond_info, atom_fac = output 34 | atom = atom_fac() 35 | if atom.bonding_capacity < 0: 36 | return None # too many Hs (e.g. 
def next_atom_state(
        bond_order: int, bond_cap: int, state: int
) -> Tuple[int, Optional[int]]:
    """Compute the bond actually made to a new atom and the follow-up state.

    :param bond_order: the bond order requested by the atom symbol.
    :param bond_cap: the bonding capacity of the new atom.
    :param state: the current derivation state; with ``0`` no bond is made.
    :return: a tuple ``(order, next_state)`` where ``order`` is the
        requested order capped by both ``state`` and ``bond_cap``, and
        ``next_state`` is the capacity left on the new atom, or ``None``
        when nothing is left.
    """
    requested = 0 if state == 0 else bond_order
    granted = min(requested, state, bond_cap)

    remaining = bond_cap - granted
    if remaining == 0:
        return granted, None
    return granted, remaining
def _process_atom_selfies_no_cache(symbol):
    """Parse an atom symbol without consulting the atom cache.

    :param symbol: a SELFIES atom symbol, e.g. ``[=C]`` or ``[N+1]``.
    :return: ``None`` if the symbol does not match
        ``SELFIES_ATOM_PATTERN`` or names an element outside ``ELEMENTS``;
        otherwise a tuple of the bond information returned by
        ``smiles_to_bond`` and a zero-argument factory creating the
        corresponding (non-aromatic) ``Atom``.
    """
    match = SELFIES_ATOM_PATTERN.match(symbol)
    if match is None:
        return None
    (bond_char, isotope_str, element,
     chiral_tag, h_str, charge_str) = match.groups()

    # symbols of the organic subset only need their element
    atom_body = symbol[1 + len(bond_char):-1]
    if atom_body in ORGANIC_SUBSET:
        factory = functools.partial(Atom, element=element, is_aromatic=False)
        return smiles_to_bond(bond_char), factory

    if element not in ELEMENTS:
        return None

    if charge_str:
        sign = 1 if (charge_str[0] == "+") else -1
        charge = sign * int(charge_str[1:])  # skip the leading +/-
    else:
        charge = 0

    factory = functools.partial(
        Atom,
        element=element,
        is_aromatic=False,
        isotope=int(isotope_str) if isotope_str else None,
        chirality=chiral_tag if chiral_tag else None,
        h_count=int(h_str[1:]) if h_str else 0,  # skip the leading H
        charge=charge
    )

    return smiles_to_bond(bond_char), factory
"[=S+1]", "[=S-1]", "[=S]", "[Br]", "[C+1]", "[C-1]", 174 | "[C]", "[Cl]", "[F]", "[H]", "[I]", "[N+1]", "[N-1]", "[N]", "[O+1]", 175 | "[O-1]", "[O]", "[P+1]", "[P-1]", "[P]", "[S+1]", "[S-1]", "[S]" 176 | ] 177 | 178 | for symbol in common_symbols: 179 | cache[symbol] = _process_atom_selfies_no_cache(symbol) 180 | return cache 181 | 182 | 183 | def _build_branch_cache(): 184 | cache = dict() 185 | for L in range(1, 4): 186 | for bond_char in ["", "=", "#"]: 187 | symbol = "[{}Branch{}]".format(bond_char, L) 188 | cache[symbol] = (smiles_to_bond(bond_char)[0], L) 189 | return cache 190 | 191 | 192 | def _build_ring_cache(): 193 | cache = dict() 194 | for L in range(1, 4): 195 | # [RingL], [=RingL], [#RingL] 196 | for bond_char in ["", "=", "#"]: 197 | symbol = "[{}Ring{}]".format(bond_char, L) 198 | order, stereo = smiles_to_bond(bond_char) 199 | cache[symbol] = (order, L, (stereo, stereo)) 200 | 201 | # [-/RingL], [\/RingL], [\-RingL], ... 202 | for lchar, rchar in itertools.product(["-", "/", "\\"], repeat=2): 203 | if lchar == rchar == "-": 204 | continue 205 | symbol = "[{}{}Ring{}]".format(lchar, rchar, L) 206 | order, lstereo = smiles_to_bond(lchar) 207 | order, rstereo = smiles_to_bond(rchar) 208 | cache[symbol] = (order, L, (lstereo, rstereo)) 209 | return cache 210 | 211 | 212 | _PROCESS_ATOM_CACHE = _build_atom_cache() 213 | 214 | _PROCESS_BRANCH_CACHE = _build_branch_cache() 215 | 216 | _PROCESS_RING_CACHE = _build_ring_cache() 217 | -------------------------------------------------------------------------------- /mofreinforce/libs/selfies/mol_graph.py: -------------------------------------------------------------------------------- 1 | # This code is adapted from the "selfies" repository: 2 | # https://github.com/aspuru-guzik-group/selfies.git 3 | 4 | # The code in this file is licensed under the Apache License, Version 2.0: 5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE 6 | 7 | # Adapted by Hyunsoo Park, 2022 8 | 9 | import 
class Atom:
    """An atom with associated specifications (e.g. charge, chirality).

    :ivar index: position of the atom in its molecular graph; ``None``
        until the atom is added to a graph.
    :ivar element: the element symbol.
    :ivar is_aromatic: whether the atom is flagged as aromatic.
    :ivar isotope: the isotope number, or ``None`` if unspecified.
    :ivar chirality: the chirality tag (``@`` or ``@@``), or ``None``.
    :ivar h_count: the explicit hydrogen count, or ``None`` if unspecified.
    :ivar charge: the formal charge.
    """

    def __init__(
            self,
            element: str,
            is_aromatic: bool,
            isotope: Optional[int] = None,
            chirality: Optional[str] = None,
            h_count: Optional[int] = None,
            charge: int = 0
    ):
        self.index = None
        self.element = element
        self.is_aromatic = is_aromatic
        self.isotope = isotope
        self.chirality = chirality
        self.h_count = h_count
        self.charge = charge

    @property
    def bonding_capacity(self):
        """The number of bonds this atom may still form.

        Computed as ``get_bonding_capacity(element, charge)`` minus the
        explicit hydrogen count (an unspecified ``h_count`` counts as 0).
        The result may be negative when ``h_count`` exceeds the capacity.
        """
        # Deliberately not wrapped in functools.lru_cache: a cache on a
        # method is keyed on ``self``, which keeps atoms alive inside the
        # cache and returns a stale value once ``charge`` or ``h_count``
        # is reassigned on an already queried atom.
        bond_cap = get_bonding_capacity(self.element, self.charge)
        bond_cap -= 0 if (self.h_count is None) else self.h_count
        return bond_cap

    def invert_chirality(self) -> None:
        """Swap ``@`` and ``@@``; any other chirality value is left as is."""
        if self.chirality == "@":
            self.chirality = "@@"
        elif self.chirality == "@@":
            self.chirality = "@"
class MolecularGraph:
    """A molecular graph.

    Molecules can be viewed as weighted undirected graphs. However, SMILES
    and SELFIES strings are more naturally represented as weighted directed
    graphs, where the direction of the edges specifies the order of atoms
    and bonds in the string.
    """

    def __init__(self, attributable=False):
        """Create an empty graph.

        :param attributable: if true, attributions passed to
            :meth:`add_attribution` are recorded; otherwise they are
            ignored and :meth:`get_attribution` always returns ``None``.
        """
        self._roots = list()  # stores root atoms, where traversal begins
        self._atoms = list()  # stores atoms in this graph
        self._bond_dict = dict()  # stores all bonds in this graph
        self._adj_list = list()  # adjacency list, representing this graph
        self._bond_counts = list()  # stores number of bonds an atom has made
        self._ring_bond_flags = list()  # stores if an atom makes a ring bond
        self._delocal_subgraph = dict()  # delocalization subgraph
        self._attribution = dict()  # attribution of each atom/bond
        self._attributable = attributable

    def __len__(self):
        """Return the number of atoms in this graph."""
        return len(self._atoms)

    def has_bond(self, a: int, b: int) -> bool:
        """Return whether a bond between atoms ``a`` and ``b`` is stored.

        The lookup is made with the smaller index first, so the argument
        order does not matter.
        """
        if a > b:
            a, b = b, a
        return (a, b) in self._bond_dict

    def has_out_ring_bond(self, src: int) -> bool:
        """Return the ring bond flag of atom ``src`` (set by
        :meth:`add_ring_bond` for both of its endpoints)."""
        return self._ring_bond_flags[src]

    def get_attribution(
            self,
            o: Union[DirectedBond, Atom]
    ) -> Optional[List[Attribution]]:
        """Return the attribution recorded for ``o``.

        :return: the stored list, or ``None`` if the graph is not
            attributable or nothing was recorded for ``o``.
        """
        if self._attributable and o in self._attribution:
            return self._attribution[o]
        return None

    def get_roots(self) -> List[int]:
        """Return the indices of the atoms that were added as roots."""
        return self._roots

    def get_atom(self, idx: int) -> Atom:
        """Return the atom stored at index ``idx``."""
        return self._atoms[idx]

    def get_atoms(self) -> List[Atom]:
        """Return the internal list of atoms (not a copy)."""
        return self._atoms

    def get_dirbond(self, src, dst) -> DirectedBond:
        """Return the directed bond stored under ``(src, dst)``.

        Unlike :meth:`has_bond`, the indices are not reordered; a
        ``KeyError`` is raised if that exact direction is not stored.
        """
        return self._bond_dict[(src, dst)]

    def get_out_dirbonds(self, src: int) -> List[DirectedBond]:
        """Return the internal list of bonds going out of atom ``src``."""
        return self._adj_list[src]

    def get_bond_count(self, idx: int) -> int:
        """Return the summed order of the bonds made by atom ``idx``."""
        return self._bond_counts[idx]

    def add_atom(self, atom: Atom, mark_root: bool = False) -> Atom:
        """Append ``atom`` to the graph and assign its index.

        :param atom: the atom to add; its ``index`` is overwritten with
            the current number of atoms.
        :param mark_root: if true, the atom is also recorded as a root.
        :return: the same atom object.
        """
        atom.index = len(self)

        if mark_root:
            self._roots.append(atom.index)
        self._atoms.append(atom)
        self._adj_list.append(list())
        self._bond_counts.append(0)
        self._ring_bond_flags.append(False)
        if atom.is_aromatic:
            # aromatic atoms start with an empty entry in the
            # delocalization subgraph
            self._delocal_subgraph[atom.index] = list()
        return atom

    def add_attribution(
            self,
            o: Union[DirectedBond, Atom],
            attr: List[Attribution]
    ) -> None:
        """Record ``attr`` for ``o``; a no-op if not attributable.

        The first list recorded for ``o`` is stored as is (not copied);
        later calls extend that stored list in place.
        """
        if self._attributable:
            if o in self._attribution:
                self._attribution[o].extend(attr)
            else:
                self._attribution[o] = attr

    def add_bond(
            self, src: int, dst: int,
            order: Union[int, float], stereo: str
    ) -> DirectedBond:
        """Add a non-ring bond from ``src`` to ``dst`` (``src < dst``).

        The bond is appended to the outgoing bonds of ``src`` and its
        order is added to the bond counts of both atoms. A bond of order
        1.5 is also entered into the delocalization subgraph.

        :return: the newly created directed bond.
        """
        assert src < dst

        bond = DirectedBond(src, dst, order, stereo, False)
        self._add_bond_at_loc(bond, -1)
        self._bond_counts[src] += order
        self._bond_counts[dst] += order

        if order == 1.5:
            self._delocal_subgraph.setdefault(src, []).append(dst)
            self._delocal_subgraph.setdefault(dst, []).append(src)
        return bond

    def add_placeholder_bond(self, src: int) -> int:
        """Append a ``None`` placeholder to the outgoing bonds of ``src``.

        :return: the position of the placeholder, which can later be
            passed as a position to :meth:`add_ring_bond`.
        """
        out_edges = self._adj_list[src]
        out_edges.append(None)
        return len(out_edges) - 1

    def add_ring_bond(
            self, a: int, b: int,
            order: Union[int, float],
            a_stereo: Optional[str], b_stereo: Optional[str],
            a_pos: int = -1, b_pos: int = -1
    ) -> None:
        """Add a ring bond between ``a`` and ``b``, stored in both directions.

        :param a_stereo: stereo information of the ``a -> b`` direction.
        :param b_stereo: stereo information of the ``b -> a`` direction.
        :param a_pos: position among the outgoing bonds of ``a``
            (``-1`` appends).
        :param b_pos: position among the outgoing bonds of ``b``
            (``-1`` appends).
        """
        a_bond = DirectedBond(a, b, order, a_stereo, True)
        b_bond = DirectedBond(b, a, order, b_stereo, True)
        self._add_bond_at_loc(a_bond, a_pos)
        self._add_bond_at_loc(b_bond, b_pos)
        self._bond_counts[a] += order
        self._bond_counts[b] += order
        self._ring_bond_flags[a] = True
        self._ring_bond_flags[b] = True

        if order == 1.5:
            self._delocal_subgraph.setdefault(a, []).append(b)
            self._delocal_subgraph.setdefault(b, []).append(a)

    def update_bond_order(
            self, a: int, b: int,
            new_order: Union[int, float]
    ) -> None:
        """Set the order of the existing bond between ``a`` and ``b``.

        For a ring bond both stored directions are updated. The bond
        counts of both atoms are adjusted by the change in order. Nothing
        happens if the bond already has ``new_order``.
        """
        assert 1 <= new_order <= 3

        if a > b:
            a, b = b, a  # swap so that a < b
        a_to_b = self._bond_dict[(a, b)]  # prev step guarantees existence
        if new_order == a_to_b.order:
            return
        elif a_to_b.ring_bond:
            b_to_a = self._bond_dict[(b, a)]
            bonds = (a_to_b, b_to_a)
        else:
            bonds = (a_to_b,)

        old_order = bonds[0].order
        for bond in bonds:
            bond.order = new_order
        self._bond_counts[a] += (new_order - old_order)
        self._bond_counts[b] += (new_order - old_order)

    def _add_bond_at_loc(self, bond, pos):
        """Register ``bond`` and place it among the out-bonds of its source.

        The bond is appended when ``pos`` is ``-1`` or equals the current
        number of outgoing bonds, replaces a ``None`` placeholder found at
        ``pos``, and is otherwise inserted at ``pos``.
        """
        self._bond_dict[(bond.src, bond.dst)] = bond

        out_edges = self._adj_list[bond.src]
        if (pos == -1) or (pos == len(out_edges)):
            out_edges.append(bond)
        elif out_edges[pos] is None:
            out_edges[pos] = bond
        else:
            out_edges.insert(pos, bond)

    def is_kekulized(self) -> bool:
        """Return ``True`` if the delocalization subgraph is empty."""
        return not self._delocal_subgraph

    def kekulize(self) -> bool:
        """Replace aromatic (order 1.5) bonds by single and double bonds.

        :return: ``True`` if the graph was already kekulized or was
            kekulized successfully; ``False`` if no perfect matching of
            the pruned delocalization subgraph exists, in which case the
            graph is left unchanged.
        """
        # Algorithm based on Depth-First article by Richard L. Apodaca
        # Reference:
        # https://depth-first.com/articles/2020/02/10/
        # a-comprehensive-treatment-of-aromaticity-in-the-smiles-language/

        if self.is_kekulized():
            return True

        ds = self._delocal_subgraph
        kept_nodes = set(itertools.filterfalse(self._prune_from_ds, ds))

        # relabel kept DS nodes to be 0, 1, 2, ...
        label_to_node = list(sorted(kept_nodes))
        node_to_label = {v: i for i, v in enumerate(label_to_node)}

        # pruned and relabelled DS
        pruned_ds = [list() for _ in range(len(kept_nodes))]
        for node in kept_nodes:
            label = node_to_label[node]
            for adj in filter(lambda v: v in kept_nodes, ds[node]):
                pruned_ds[label].append(node_to_label[adj])

        matching = find_perfect_matching(pruned_ds)
        if matching is None:
            return False

        # de-aromatize and then make double bonds
        for node in ds:
            for adj in ds[node]:
                self.update_bond_order(node, adj, new_order=1)
            self._atoms[node].is_aromatic = False
            self._bond_counts[node] = int(self._bond_counts[node])

        # NOTE(review): enumerate() pairs each label with matching[label],
        # so this presumably relies on find_perfect_matching returning the
        # partner label per position - confirm against matching_utils.
        for matched_labels in enumerate(matching):
            matched_nodes = tuple(label_to_node[i] for i in matched_labels)
            self.update_bond_order(*matched_nodes, new_order=2)

        self._delocal_subgraph = dict()  # clear DS
        return True

    def _prune_from_ds(self, node):
        """Return whether ``node`` is left out of the matching subgraph.

        Used as the predicate of ``itertools.filterfalse`` in
        :meth:`kekulize`: nodes for which this returns a false value are
        kept.
        """
        adj_nodes = self._delocal_subgraph[node]
        if not adj_nodes:
            return True  # aromatic atom with no aromatic bonds

        atom = self._atoms[node]
        valences = AROMATIC_VALENCES[atom.element]

        # each bond in DS has order 1.5 - we treat them as single bonds
        used_electrons = int(self._bond_counts[node] - 0.5 * len(adj_nodes))

        if atom.h_count is None:  # account for implicit Hs
            assert atom.charge == 0
            return any(used_electrons == v for v in valences)
        else:
            valence = valences[-1] - atom.charge
            used_electrons += atom.h_count
            free_electrons = valence - used_electrons
            # kept only if a non-negative, odd number of electrons is free
            return not ((free_electrons >= 0) and (free_electrons % 2 != 0))
-------------------------------------------------------------------------------- /mofreinforce/libs/selfies/utils/encoding_utils.py: -------------------------------------------------------------------------------- 1 | # This code is adapted from the "selfies" repository: 2 | # https://github.com/aspuru-guzik-group/selfies.git 3 | 4 | # The code in this file is licensed under the Apache License, Version 2.0: 5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE 6 | 7 | # Adapted by Hyunsoo Park, 2022 8 | 9 | from typing import Dict, List, Tuple, Union 10 | 11 | from .selfies_utils import len_selfies, split_selfies 12 | 13 | 14 | def selfies_to_encoding( 15 | selfies: str, 16 | vocab_stoi: Dict[str, int], 17 | pad_to_len: int = -1, 18 | enc_type: str = 'both' 19 | ) -> Union[List[int], List[List[int]], Tuple[List[int], List[List[int]]]]: 20 | """Converts a SELFIES string into its label (integer) 21 | and/or one-hot encoding. 22 | 23 | A label encoded output will be a list of shape ``(L,)`` and a 24 | one-hot encoded output will be a 2D list of shape ``(L, len(vocab_stoi))``, 25 | where ``L`` is the symbol length of the SELFIES string. Optionally, 26 | the SELFIES string can be padded before it is encoded. 27 | 28 | :param selfies: the SELFIES string to be encoded. 29 | :param vocab_stoi: a dictionary that maps SELFIES symbols to indices, 30 | which must be non-negative and contiguous, starting from 0. 31 | If the SELFIES string is to be padded, then the special padding symbol 32 | ``[nop]`` must also be a key in this dictionary. 33 | :param pad_to_len: the length that the SELFIES string string is padded to. 34 | If this value is less than or equal to the symbol length of the 35 | SELFIES string, then no padding is added. Defaults to ``-1``. 36 | :param enc_type: the type of encoding of the output: 37 | ``label`` or ``one_hot`` or ``both``. 38 | If this value is ``both``, then a tuple of the label and one-hot 39 | encodings is returned. 
def encoding_to_selfies(
        encoding: Union[List[int], List[List[int]]],
        vocab_itos: Dict[int, str],
        enc_type: str,
) -> str:
    """Decode a label or one-hot encoding back into a SELFIES string.

    A label encoding is a flat list of indices; a one-hot encoding is a
    list of rows, where the position of the first ``1`` in each row is
    taken as the index.

    :param encoding: the label or one-hot encoding to decode.
    :param vocab_itos: mapping from index to SELFIES symbol.
    :param enc_type: ``label`` or ``one_hot``, describing ``encoding``.
    :return: the concatenation of the decoded symbols.
    :raises ValueError: if ``enc_type`` is neither ``label`` nor
        ``one_hot``.

    :Example:

    >>> import selfies as sf
    >>> one_hot = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]
    >>> vocab_itos = {0: "[nop]", 1: "[C]", 2: "[F]"}
    >>> sf.encoding_to_selfies(one_hot, vocab_itos, enc_type="one_hot")
    '[C][F][nop]'
    """

    if enc_type not in ("label", "one_hot"):
        raise ValueError("enc_type must be in ('label', 'one_hot')")

    if enc_type == "one_hot":
        # collapse every row to the index of its hot entry
        labels = [row.index(1) for row in encoding]
    else:
        labels = encoding

    symbols = [vocab_itos[label] for label in labels]
    return "".join(symbols)
def batch_selfies_to_flat_hot(
        selfies_batch: List[str],
        vocab_stoi: Dict[str, int],
        pad_to_len: int = -1,
) -> List[List[int]]:
    """One-hot encode each SELFIES string and flatten it to a single list.

    Every string is encoded with :func:`selfies.selfies_to_encoding`
    (``enc_type="one_hot"``), forwarding ``vocab_stoi`` and ``pad_to_len``,
    and its rows are then concatenated in order.

    :param selfies_batch: the SELFIES strings to encode.
    :param vocab_stoi: mapping from SELFIES symbol to index.
    :param pad_to_len: padding length forwarded to the encoder.
        Defaults to ``-1``.
    :return: one flattened one-hot list per input string.

    :Example:

    >>> import selfies as sf
    >>> batch = ["[C]", "[C][C]"]
    >>> vocab_stoi = {"[nop]": 0, "[C]": 1}
    >>> sf.batch_selfies_to_flat_hot(batch, vocab_stoi, 2)
    [[0, 1, 1, 0], [0, 1, 0, 1]]
    """

    return [
        [
            bit
            for row in selfies_to_encoding(selfies, vocab_stoi, pad_to_len,
                                           enc_type="one_hot")
            for bit in row
        ]
        for selfies in selfies_batch
    ]
def batch_flat_hot_to_selfies(
        one_hot_batch: List[List[int]],
        vocab_itos: Dict[int, str],
) -> List[str]:
    """Decode flattened one-hot encodings back into SELFIES strings.

    Each flat list is cut into consecutive rows of ``len(vocab_itos)``
    entries and handed to :func:`selfies.encoding_to_selfies`.

    :param one_hot_batch: flattened one-hot encodings; the length of each
        must be a multiple of ``len(vocab_itos)``.
    :param vocab_itos: mapping from index to SELFIES symbol.
    :return: one decoded SELFIES string per input encoding.
    :raises ValueError: if an encoding's length is not divisible by the
        vocabulary size.

    :Example:

    >>> import selfies as sf
    >>> batch = [[0, 1, 1, 0], [0, 1, 0, 1]]
    >>> vocab_itos = {0: "[nop]", 1: "[C]"}
    >>> sf.batch_flat_hot_to_selfies(batch, vocab_itos)
    ['[C][nop]', '[C][C]']
    """

    width = len(vocab_itos)  # number of columns in one one-hot row
    decoded = []

    for flat in one_hot_batch:
        if len(flat) % width != 0:
            raise ValueError("size of vector in one_hot_batch not divisible "
                             "by the length of the vocabulary.")

        # rebuild the (symbols x vocabulary) layout row by row
        rows = [flat[start: start + width]
                for start in range(0, len(flat), width)]
        decoded.append(
            encoding_to_selfies(rows, vocab_itos, enc_type="one_hot")
        )

    return decoded
def find_perfect_matching(graph: List[List[int]]) -> Optional[List[int]]:
    """Search for a perfect matching of an undirected graph.

    A greedy matching is built first; each node that is still unmatched
    then becomes the root of an augmenting-path search, and the path that
    is found is flipped into the matching.

    :param graph: adjacency list of the graph (no self-loops expected).
    :return: a list whose i-th element is the node matched with node i,
        or ``None`` as soon as one augmenting-path search fails.
    """

    matching = _greedy_matching(graph)

    unmatched = set(
        node for node, mate in enumerate(matching) if mate is None
    )
    while unmatched:
        root = unmatched.pop()
        path = _find_augmenting_path(graph, root, matching)
        if path is None:
            return None

        _flip_augmenting_path(matching, path)
        # both ends of the path are matched now
        unmatched.discard(path[0])
        unmatched.discard(path[-1])

    return matching


def _greedy_matching(graph):
    """Build a starting matching, serving low-degree nodes first."""
    matching = [None] * len(graph)
    # free_degrees[i] = number of unmatched neighbours of node i
    free_degrees = [len(neighbours) for neighbours in graph]

    heap = [(degree, node) for node, degree in enumerate(free_degrees)]
    heapq.heapify(heap)

    while heap:
        _, node = heapq.heappop(heap)

        # stale heap entry: node already matched or nothing left to pair with
        if matching[node] is not None or free_degrees[node] == 0:
            continue

        # pair the node with its first unmatched neighbour
        mate = next(i for i in graph[node] if matching[i] is None)
        matching[node] = mate
        matching[mate] = node

        # neighbours of both ends lose one free neighbour each
        for adj in itertools.chain(graph[node], graph[mate]):
            free_degrees[adj] -= 1
            if matching[adj] is None and free_degrees[adj] > 0:
                heapq.heappush(heap, (free_degrees[adj], adj))

    return matching


def _find_augmenting_path(graph, root, matching):
    """Breadth-first search from ``root`` for another unmatched node.

    Returns the path as a flat list of node pairs (starting at the far
    end), or ``None`` when the search is exhausted.
    """
    assert matching[root] is None

    # parents[v] = [previous node on the path, neighbour used to reach v];
    # None marks a node the search has not reached yet
    parents = [None] * len(graph)
    parents[root] = [None, None]

    frontier = deque([root])
    other_end = None

    while frontier and other_end is None:
        node = frontier.popleft()

        for adj in graph[node]:
            mate = matching[adj]
            if mate is None:
                if adj != root:  # reached a second unmatched node
                    parents[adj] = [node, adj]
                    other_end = adj
                    break
            elif parents[mate] is None:  # mate of adj not visited yet
                parents[mate] = [node, adj]
                frontier.append(mate)

    if other_end is None:
        return None

    # walk back to the root, emitting (neighbour, previous node) pairs
    path = []
    cursor = other_end
    while cursor != root:
        previous, via = parents[cursor]
        path.append(via)
        path.append(previous)
        cursor = previous
    return path


def _flip_augmenting_path(matching, path):
    """Match every consecutive pair ``(path[2k], path[2k + 1])`` in place."""
    for start in range(0, len(path), 2):
        left = path[start]
        right = path[start + 1]
        matching[left] = right
        matching[right] = left
def len_selfies(selfies: str) -> int:
    """Count the symbols of a SELFIES string.

    The count is the number of ``[`` characters plus the number of ``.``
    characters in the string.

    :param selfies: a SELFIES string.
    :return: the symbol length of the SELFIES string.

    :Example:

    >>> import selfies as sf
    >>> sf.len_selfies("[C][=C][F].[C]")
    5
    """

    return sum(selfies.count(marker) for marker in ("[", "."))


def split_selfies(selfies: str) -> Iterator[str]:
    """Lazily yield the symbols of a SELFIES string in order.

    Scanning starts at the first ``[``. A ``.`` is yielded as its own
    symbol only when it directly follows a closing ``]``.

    :param selfies: a SELFIES string.
    :return: an iterator over the symbols of the string.
    :raises ValueError: during iteration, when a ``[`` has no closing
        ``]``.

    :Example:

    >>> import selfies as sf
    >>> list(sf.split_selfies("[C][=C][F].[C]"))
    ['[C]', '[=C]', '[F]', '.', '[C]']
    """

    cursor = selfies.find("[")
    end = len(selfies)

    while 0 <= cursor < end:
        closing = selfies.find("]", cursor + 1)
        if closing == -1:
            raise ValueError("malformed SELFIES string, hanging '[' bracket")

        yield selfies[cursor: closing + 1]
        cursor = closing + 1

        # a dot right after a symbol separates two fragments
        if selfies.startswith(".", cursor):
            yield "."
            cursor += 1
def get_alphabet_from_selfies(selfies_iter: Iterable[str]) -> Set[str]:
    """Collect the set of symbols used by a collection of SELFIES strings.

    Every string is tokenized with ``split_selfies``; the dot symbol ``.``
    is removed from the result.

    :param selfies_iter: an iterable of SELFIES strings.
    :return: the set of SELFIES symbols seen, without ``.``.

    :Example:

    >>> import selfies as sf
    >>> selfies_list = ["[C][F][O]", "[C].[O]", "[F][F]"]
    >>> alphabet = sf.get_alphabet_from_selfies(selfies_list)
    >>> sorted(list(alphabet))
    ['[C]', '[F]', '[O]']
    """

    alphabet = set()
    for selfies in selfies_iter:
        alphabet.update(split_selfies(selfies))
    alphabet.discard(".")
    return alphabet
class BaselineModel(nn.Module):
    """Baseline predictor combining three one-dimensional read-outs.

    The metal-cluster input, the topology index and the organic-linker
    token sequence are each reduced to a single logit; the three logits
    are concatenated, mixed by a final linear layer and squashed with a
    sigmoid.

    :param vocab_dim: number of entries of the linker-token embedding.
    :param mc_dim: input width of the metal-cluster linear layer.
    :param topo_dim: number of entries of the topology embedding.
    :param embed_dim: embedding width for topology and linker tokens.
    :param hid_dim: hidden width of the linker RNN.
    """

    def __init__(self, vocab_dim, mc_dim, topo_dim, embed_dim, hid_dim):
        super().__init__()
        # metal cluster: linear read-out applied directly to the `mc` input
        self.rc_mc = nn.Linear(mc_dim, 1)
        # topology
        self.embedding_topo = nn.Embedding(topo_dim, embed_dim)
        self.rc_topo = nn.Sequential(
            nn.Linear(embed_dim, 1),
        )
        # organic linker
        self.embedding_ol = nn.Embedding(vocab_dim, embed_dim)
        self.rnn = nn.RNN(input_size=embed_dim, hidden_size=hid_dim, num_layers=1)
        self.rc_ol = nn.Linear(hid_dim, 1)
        # total
        self.rc_total = nn.Linear(3, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, batch):
        """Compute the prediction for one batch.

        :param batch: mapping with the keys ``mc``, ``topo``, ``ol_pad``
            (padded linker tokens, batch first) and ``ol_len`` (true
            length of every linker sequence).
        :return: tensor of shape ``[B, 1]`` holding sigmoid outputs.
        """
        mc = batch["mc"]
        topo = batch["topo"]
        ol_pad = batch["ol_pad"]
        ol_len = batch["ol_len"]

        # mc
        logit_mc = self.rc_mc(mc)  # [B, 1]

        # topo
        logit_topo = self.embedding_topo(topo)  # [B, embed_dim]
        logit_topo = self.rc_topo(logit_topo)  # [B, 1]

        # ol
        embed_ol = self.embedding_ol(ol_pad)  # [B, pad_len, embed_dim]
        packed_ol = pack_padded_sequence(
            embed_ol, ol_len, batch_first=True, enforce_sorted=False
        )
        # Use the final hidden state rather than the last padded time step:
        # after unpacking, that step is zero padding for every sequence
        # shorter than the longest one in the batch, so such sequences
        # would not influence the prediction at all.
        _, hidden = self.rnn(packed_ol)  # hidden: [num_layers, B, hid_dim]
        logit_ol = self.rc_ol(hidden[-1])  # [B, hid_dim] -> [B, 1]

        # total
        logit_total = torch.cat([logit_topo, logit_mc, logit_ol], dim=-1)
        logit_total = self.rc_total(logit_total)
        logit_total = self.sigmoid(logit_total)

        return logit_total
import json
from sacred import Experiment

ex = Experiment("predictor")


def _load_json(path):
    """Read a JSON file and close the handle once it is parsed."""
    with open(path) as f:
        return json.load(f)


# Lookup tables read once at import time; their sizes feed the model
# dimensions in ``config`` below.
mc_to_idx = _load_json("data/mc_to_idx.json")
topo_to_idx = _load_json("data/topo_to_idx.json")
vocab_to_idx = _load_json("data/vocab_to_idx.json")  # vocab for selfies


def _loss_names(d):
    """Return the default loss switches overridden by the entries of ``d``."""
    ret = {
        "regression": 0,  # regression
    }
    ret.update(d)
    return ret


# Default configuration; the named configs below override single entries.
@ex.config
def config():
    seed = 0
    exp_name = "predictor"

    dataset_dir = "###"
    loss_names = _loss_names({"regression": 1})

    # model setting
    max_len = 128  # cls + mc + topo + ol_len
    vocab_dim = len(vocab_to_idx)
    mc_dim = len(mc_to_idx)
    topo_dim = len(topo_to_idx)
    weight_loss = None

    # transformer setting
    hid_dim = 256
    num_heads = 4  # hid_dim / 64
    num_layers = 4
    mlp_ratio = 4
    drop_rate = 0.1

    # run setting
    batch_size = 128
    per_gpu_batchsize = 64
    load_path = ""
    log_dir = "predictor/logs"
    num_workers = 8  # recommend num_gpus * 4
    num_nodes = 1
    devices = 2
    precision = 16

    # downstream
    downstream = ""
    n_classes = 0
    threshold_classification = None

    # Optimizer Setting
    optim_type = "adamw"  # adamw, adam, sgd (momentum=0.9)
    learning_rate = 1e-4
    weight_decay = 1e-2
    decay_power = (
        1  # default polynomial decay, [cosine, constant, constant_with_warmup]
    )
    max_epochs = 50
    max_steps = -1  # num_data * max_epoch // batch_size (accumulate_grad_batches)
    warmup_steps = 0.05  # int or float ( max_steps * warmup_steps)
    end_lr = 0
    lr_mult = 1  # multiply lr for downstream heads

    # trainer setting
    resume_from = None
    val_check_interval = 1.0
    test_only = False

    # normalize (when regression)
    mean = None
    std = None


@ex.named_config
def env_ifactor():
    pass


# NOTE(review): ``std`` is negative in round1 but positive in round2/3 --
# confirm this is intended before changing it; presumably existing
# checkpoints were trained with exactly these values.
@ex.named_config
def regression_qkh_round1():
    exp_name = "regression_qkh_round1"
    dataset_dir = "data/dataset_predictor/qkh/round1"

    # trainer
    max_epochs = 50
    batch_size = 64
    per_gpu_batchsize = 16

    # normalize (when regression)
    mean = -20.331
    std = -10.383


@ex.named_config
def regression_qkh_round2():
    exp_name = "regression_qkh_round2"
    dataset_dir = "data/dataset_predictor/qkh/round2"

    # trainer
    max_epochs = 50
    batch_size = 64
    per_gpu_batchsize = 2

    # normalize (when regression)
    mean = -21.068
    std = 10.950


@ex.named_config
def regression_qkh_round3():
    exp_name = "regression_qkh_round3"
    dataset_dir = "data/dataset_predictor/qkh/round3"

    # trainer
    max_epochs = 50
    batch_size = 64
    per_gpu_batchsize = 2

    # normalize (when regression)
    mean = -21.810
    std = 11.452


# v1_selectivity


@ex.named_config
def regression_selectivity_round1():
    exp_name = "regression_selectivity_round1"
    dataset_dir = "data/dataset_predictor/selectivity/round1"

    # trainer
    max_epochs = 50
    batch_size = 128
    per_gpu_batchsize = 16

    # normalize (when regression)
    mean = 1.872
    std = 1.922


@ex.named_config
def regression_selectivity_round2():
    exp_name = "regression_selectivity_round2"
    dataset_dir = "data/dataset_predictor/selectivity/round2"

    # trainer
    max_epochs = 50
    batch_size = 128
    per_gpu_batchsize = 16

    # normalize (when regression)
    mean = 2.085
    std = 2.052


@ex.named_config
def regression_selectivity_round3():
    exp_name = "regression_selectivity_round3"
    dataset_dir = "data/dataset_predictor/selectivity/round3"

    # trainer
    max_epochs = 50
    batch_size = 128
    per_gpu_batchsize = 16

    # normalize (when regression)
    mean = 2.258
    std = 2.128
import json
from sacred import Experiment

ex = Experiment("predictor")


def _load_json(path):
    """Read a JSON file and close the handle once it is parsed."""
    with open(path) as f:
        return json.load(f)


# Lookup tables read once at import time; their sizes feed the model
# dimensions in ``config`` below.
mc_to_idx = _load_json("data/mc_to_idx.json")
topo_to_idx = _load_json("data/topo_to_idx.json")
vocab_to_idx = _load_json("data/vocab_to_idx.json")  # vocab for selfies


def _loss_names(d):
    """Return the default loss switches overridden by the entries of ``d``."""
    ret = {
        "regression": 0,  # regression
    }
    ret.update(d)
    return ret


# Default configuration; the named configs below override single entries.
@ex.config
def config():
    seed = 0
    exp_name = "predictor"

    dataset_dir = "###"
    loss_names = _loss_names({"regression": 1})

    # model setting
    max_len = 128  # cls + mc + topo + ol_len
    vocab_dim = len(vocab_to_idx)
    mc_dim = len(mc_to_idx)
    topo_dim = len(topo_to_idx)
    weight_loss = None

    # transformer setting
    hid_dim = 256
    num_heads = 4  # hid_dim / 64
    num_layers = 4
    mlp_ratio = 4
    drop_rate = 0.1

    # run setting
    batch_size = 128
    per_gpu_batchsize = 64
    load_path = ""
    log_dir = "predictor/logs"
    num_workers = 8  # recommend num_gpus * 4
    num_nodes = 1
    devices = 2
    precision = 16

    # downstream
    downstream = ""
    n_classes = 0
    threshold_classification = None

    # Optimizer Setting
    optim_type = "adamw"  # adamw, adam, sgd (momentum=0.9)
    learning_rate = 1e-4
    weight_decay = 1e-2
    decay_power = (
        1  # default polynomial decay, [cosine, constant, constant_with_warmup]
    )
    max_epochs = 50
    max_steps = -1  # num_data * max_epoch // batch_size (accumulate_grad_batches)
    warmup_steps = 0.05  # int or float ( max_steps * warmup_steps)
    end_lr = 0
    lr_mult = 1  # multiply lr for downstream heads

    # trainer setting
    resume_from = None
    val_check_interval = 1.0
    test_only = False

    # normalize (when regression)
    mean = None
    std = None


@ex.named_config
def env_ifactor():
    pass


@ex.named_config
def regression_vf():
    exp_name = "regression_vf"
    dataset_dir = "data/dataset_predictor/vf"

    # trainer
    max_epochs = 50
    batch_size = 128
    per_gpu_batchsize = 16


# NOTE(review): ``std`` is negative in old_round1, old_round2 and new_round1
# but positive in new_round2/3 -- confirm this is intended before changing
# it; presumably existing checkpoints were trained with exactly these values.
@ex.named_config
def regression_qkh_old_round1():
    exp_name = "regression_qkh_old_round1"
    dataset_dir = "data/dataset_predictor/qkh/old_round1"

    # trainer
    max_epochs = 50
    batch_size = 64
    per_gpu_batchsize = 16

    # normalize (when regression)
    mean = -19.408
    std = -9.172


@ex.named_config
def regression_qkh_old_round2():
    exp_name = "regression_qkh_old_round2"
    dataset_dir = "data/dataset_predictor/qkh/old_round2"

    # trainer
    max_epochs = 50
    batch_size = 64
    per_gpu_batchsize = 16

    # normalize (when regression)
    mean = -19.886
    std = -9.811


@ex.named_config
def regression_qkh_new_round1():
    exp_name = "regression_qkh_new_round1"
    dataset_dir = "data/dataset_predictor/qkh/new_round1"

    # trainer
    max_epochs = 50
    batch_size = 64
    per_gpu_batchsize = 16

    # normalize (when regression)
    mean = -20.331
    std = -10.383


@ex.named_config
def regression_qkh_new_round2():
    exp_name = "regression_qkh_new_round2"
    dataset_dir = "data/dataset_predictor/qkh/new_round2"

    # trainer
    max_epochs = 50
    batch_size = 64
    per_gpu_batchsize = 2

    # normalize (when regression)
    mean = -21.068
    std = 10.950


@ex.named_config
def regression_qkh_new_round3():
    exp_name = "regression_qkh_new_round3"
    dataset_dir = "data/dataset_predictor/qkh/new_round3"

    # trainer
    max_epochs = 50
    batch_size = 64
    per_gpu_batchsize = 2

    # normalize (when regression)
    mean = -21.810
    std = 11.452


# v1_selectivity


@ex.named_config
def regression_selectivity_new_round1():
    exp_name = "regression_selectivity_new_round1"
    dataset_dir = "data/dataset_predictor/selectivity/new_round1"

    # trainer
    max_epochs = 50
    batch_size = 128
    per_gpu_batchsize = 16

    # normalize (when regression)
    mean = 1.872
    std = 1.922


@ex.named_config
def regression_selectivity_new_round2():
    exp_name = "regression_selectivity_new_round2"
    dataset_dir = "data/dataset_predictor/selectivity/new_round2"

    # trainer
    max_epochs = 50
    batch_size = 128
    per_gpu_batchsize = 16

    # normalize (when regression)
    mean = 2.085
    std = 2.052


@ex.named_config
def regression_selectivity_new_round3():
    exp_name = "regression_selectivity_new_round3"
    dataset_dir = "data/dataset_predictor/selectivity/new_round3"

    # trainer
    max_epochs = 50
    batch_size = 128
    per_gpu_batchsize = 16

    # normalize (when regression)
    mean = 2.258
    std = 2.128
class Datamodule(LightningDataModule):
    """Lightning data module serving the train/val/test ``MOFDataset`` splits.

    :param _config: experiment configuration; the keys ``dataset_dir``,
        ``batch_size``, ``num_workers``, ``max_len`` and ``loss_names``
        are read.
    """

    def __init__(self, _config):
        super().__init__()
        self.dataset_dir = _config["dataset_dir"]
        self.batch_size = _config["batch_size"]
        self.num_workers = _config["num_workers"]
        self.max_len = _config["max_len"]
        # names of the losses switched on in the configuration
        self.tasks = [k for k, v in _config["loss_names"].items() if v >= 1]

    @property
    def dataset_cls(self):
        """Dataset class instantiated for every split."""
        return MOFDataset

    def set_train_dataset(self):
        self.train_dataset = self.dataset_cls(
            dataset_dir=self.dataset_dir,
            split="train",
        )

    def set_val_dataset(self):
        self.val_dataset = self.dataset_cls(
            dataset_dir=self.dataset_dir,
            split="val",
        )

    def set_test_dataset(self):
        self.test_dataset = self.dataset_cls(
            dataset_dir=self.dataset_dir,
            split="test",
        )

    def setup(self, stage: Optional[str] = None):
        """Create the datasets needed for ``stage`` and the collate function.

        ``None`` prepares every split. ``"validate"`` is handled as well so
        that a validation-only run finds its dataset.
        """
        if stage in (None, "fit"):
            self.set_train_dataset()

        if stage in (None, "fit", "validate"):
            self.set_val_dataset()

        if stage in (None, "test"):
            self.set_test_dataset()

        self.collate = partial(self.dataset_cls.collate, max_len=self.max_len)

    def _build_loader(self, dataset, *, drop_last):
        """Wrap ``dataset`` in a ``DataLoader`` with the shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate,
            drop_last=drop_last,
            pin_memory=True,
        )

    def train_dataloader(self):
        # NOTE(review): no ``shuffle=True`` is passed here -- confirm that
        # shuffling is provided elsewhere (e.g. by a distributed sampler).
        return self._build_loader(self.train_dataset, drop_last=True)

    def val_dataloader(self):
        # keep the last partial batch: evaluation must see every sample
        return self._build_loader(self.val_dataset, drop_last=False)

    def test_dataloader(self):
        # keep the last partial batch: evaluation must see every sample
        return self._build_loader(self.test_dataset, drop_last=False)
class MOFDataset(Dataset):
    """Dataset backed by one ``<split>.json`` file of per-MOF records.

    The JSON file maps a MOF name to a dictionary of fields; an item is
    that dictionary plus the key ``mof_name``.

    :param dataset_dir: directory containing the split files.
    :param split: one of ``train``, ``test`` or ``val``.
    """

    def __init__(
        self,
        dataset_dir,
        split,
    ):
        assert split in ["train", "test", "val"]

        # load dict_mof
        path_dict_mof = Path(dataset_dir, f"{split}.json")
        print(f"read file : {path_dict_mof}")
        # context manager so the handle is closed as soon as parsing ends
        with open(path_dict_mof, "r") as f:
            self.dict_mof = json.load(f)
        self.mof_name = list(self.dict_mof.keys())

    def __len__(self):
        return len(self.mof_name)

    def __getitem__(self, idx):
        mof_name = self.mof_name[idx]
        ret = {"mof_name": mof_name}
        ret.update(self.dict_mof[mof_name])
        return ret

    @staticmethod
    def collate(batch, max_len):
        """Merge a list of items into one dictionary of batched values.

        ``target`` becomes a float tensor, ``mc`` and ``topo`` become long
        tensors, and every ``ol`` sequence is zero-padded or truncated to
        ``max_len - 3`` entries before being stacked into a long tensor.
        All other keys are kept as plain lists, with ``None`` for items
        that lack the key.

        :param batch: list of items as returned by ``__getitem__``.
        :param max_len: total sequence length including the three
            non-linker tokens (cls, mc, topo).
        :return: dictionary of batched values.
        """
        keys = {key for sample in batch for key in sample}
        dict_batch = {k: [sample.get(k) for sample in batch] for k in keys}
        # target
        dict_batch["target"] = torch.FloatTensor(dict_batch["target"])
        # mc (idx)
        dict_batch["mc"] = torch.LongTensor(dict_batch["mc"])
        # topo (idx)
        dict_batch["topo"] = torch.LongTensor(dict_batch["topo"])
        # ol (selfies)
        ol_len = max_len - 3  # cls, mc, topo
        dict_batch["ol"] = torch.LongTensor(
            [
                ol + [0] * (ol_len - len(ol)) if len(ol) < ol_len else ol[:ol_len]
                for ol in dict_batch["ol"]
            ]
        )
        return dict_batch
class Accuracy(Metric):
    """Running accuracy accumulated over ``update`` calls.

    Inputs with more than one dimension are reduced with ``argmax`` over
    the last axis; one-dimensional inputs are thresholded at ``0.5``.
    Positions whose target equals ``-100`` are excluded.
    """

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, logits, target):
        logits, target = (
            logits.detach().to(self.correct.device),
            target.detach().to(self.correct.device),
        )
        if len(logits.shape) > 1:
            preds = logits.argmax(dim=-1)
        else:
            # binary accuracy
            # ``detach()`` shares storage with the caller's tensor, so
            # threshold a copy instead of overwriting the caller's values.
            preds = logits.clone()
            preds[logits >= 0.5] = 1
            preds[logits < 0.5] = 0

        preds = preds[target != -100]
        target = target[target != -100]

        if target.numel() == 0:
            # nothing to score; return value kept from the original code
            return 1

        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return self.correct / self.total


class Scalar(Metric):
    """Running mean of the scalar values passed to ``update``."""

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, scalar):
        if isinstance(scalar, torch.Tensor):
            scalar = scalar.detach().to(self.scalar.device)
        else:
            scalar = torch.tensor(scalar).float().to(self.scalar.device)
        self.scalar += scalar
        self.total += 1

    def compute(self):
        return self.scalar / self.total
class Pooler(nn.Module):
    """Pool a sequence by projecting the token at a fixed position.

    The token at position ``index`` along the second axis is passed
    through a square linear layer followed by ``tanh``.
    """

    def __init__(self, hidden_size, index=0):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()
        self.index = index

    def forward(self, hidden_states):
        picked = hidden_states[:, self.index]
        return self.activation(self.dense(picked))


class RegressionHead(nn.Module):
    """
    head for Regression: a single linear layer with one output
    """

    def __init__(self, hid_dim):
        super().__init__()
        self.fc = nn.Linear(hid_dim, 1)

    def forward(self, x):
        return self.fc(x)


class ClassificationHead(nn.Module):
    """
    head for Classification: one output for two classes, otherwise
    ``n_classes`` outputs; ``forward`` also returns the binary flag
    """

    def __init__(self, hid_dim, n_classes):
        super().__init__()

        is_binary = n_classes == 2
        out_dim = 1 if is_binary else n_classes
        self.fc = nn.Linear(hid_dim, out_dim)
        self.binary = is_binary

    def forward(self, x):
        logits = self.fc(x)
        return logits, self.binary
def __init__(self, config):
    """Build the transformer, the token embeddings and the task heads.

    ``config`` is read like a dictionary; the keys used here are
    ``hid_dim``, ``num_layers``, ``num_heads``, ``mlp_ratio``,
    ``drop_rate``, ``mc_dim``, ``topo_dim``, ``vocab_dim``, ``max_len``,
    ``loss_names``, ``mean``, ``std``, ``load_path`` and ``test_only``.
    """
    super().__init__()
    # makes ``config`` available as ``self.hparams.config`` (used below)
    self.save_hyperparameters()
    # build transformer
    self.transformer = Transformer(
        embed_dim=config["hid_dim"],
        depth=config["num_layers"],
        num_heads=config["num_heads"],
        mlp_ratio=config["mlp_ratio"],
        drop_rate=config["drop_rate"],
    )

    # metal node embedding
    self.mc_embedding = nn.Embedding(config["mc_dim"], config["hid_dim"])
    self.mc_embedding.apply(module_utils.init_weights)

    # topology embedding
    self.topo_embedding = nn.Embedding(config["topo_dim"], config["hid_dim"])
    self.topo_embedding.apply(module_utils.init_weights)

    # organic linker embedding (index 0 is the padding token)
    self.ol_embedding = nn.Embedding(
        config["vocab_dim"], config["hid_dim"], padding_idx=0
    )
    self.ol_embedding.apply(module_utils.init_weights)

    # class token: a linear layer applied to a scalar input
    self.cls_embeddings = nn.Linear(1, config["hid_dim"])
    self.cls_embeddings.apply(module_utils.init_weights)

    # position embedding, initialised to zeros
    # max_len = ol_max_len (100) + cls + mc + topo
    self.pos_embeddings = nn.Parameter(
        torch.zeros(1, config["max_len"], config["hid_dim"])
    )
    # self.pos_embeddings.apply(module_utils.init_weights)

    # pooler
    self.pooler = heads.Pooler(config["hid_dim"])
    self.pooler.apply(module_utils.init_weights)

    # ===================== loss =====================

    # regression
    if self.hparams.config["loss_names"]["regression"] > 0:
        self.regression_head = heads.RegressionHead(config["hid_dim"])
        self.regression_head.apply(module_utils.init_weights)
        # normalization
        self.mean = config["mean"]
        self.std = config["std"]
        self.normalizer = module_utils.Normalizer(self.mean, self.std)

    module_utils.set_metrics(self)
    module_utils.set_task(self)
    # ===================== load downstream (test_only) ======================

    # weights are loaded only when a path is given AND the run is test-only;
    # strict=False tolerates keys that do not match between checkpoint and model
    if config["load_path"] != "" and config["test_only"]:
        ckpt = torch.load(config["load_path"], map_location="cpu")
        state_dict = ckpt["state_dict"]
        self.load_state_dict(state_dict, strict=False)
        print(f"load model : {config['load_path']}")
heads.RegressionHead(config["hid_dim"]) 60 | self.regression_head.apply(module_utils.init_weights) 61 | # normalization 62 | self.mean = config["mean"] 63 | self.std = config["std"] 64 | self.normalizer = module_utils.Normalizer(self.mean, self.std) 65 | 66 | module_utils.set_metrics(self) 67 | module_utils.set_task(self) 68 | # ===================== load downstream (test_only) ====================== 69 | 70 | if config["load_path"] != "" and config["test_only"]: 71 | ckpt = torch.load(config["load_path"], map_location="cpu") 72 | state_dict = ckpt["state_dict"] 73 | self.load_state_dict(state_dict, strict=False) 74 | print(f"load model : {config['load_path']}") 75 | 76 | def infer(self, batch): 77 | mc = batch["mc"] # [B] 78 | topo = batch["topo"] # [B] 79 | ol = batch["ol"] # [B, ol_len] 80 | batch_size = len(mc) 81 | 82 | mc_embeds = self.mc_embedding(mc).unsqueeze(1) # [B, 1, hid_dim] 83 | topo_embeds = self.topo_embedding(topo).unsqueeze(1) # [B, 1, hid_dim] 84 | 85 | ol_embeds = self.ol_embedding(ol) # [B, ol_len, hid_dim] 86 | 87 | cls_tokens = torch.zeros(batch_size).to(ol_embeds) # [B] 88 | cls_embeds = self.cls_embeddings(cls_tokens[:, None, None]) # [B, 1, hid_dim] 89 | 90 | # total_embedding and mask 91 | co_embeds = torch.cat( 92 | [cls_embeds, mc_embeds, topo_embeds, ol_embeds], dim=1 93 | ) # [B, max_len, hid_dim] 94 | co_masks = torch.cat( 95 | [torch.ones([batch_size, 3]).to(ol), (ol != 0).float()], dim=1 96 | ) 97 | 98 | # add pos_embeddings 99 | final_embeds = co_embeds + self.pos_embeddings 100 | final_embeds = self.transformer.pos_drop(final_embeds) 101 | 102 | # transformer blocks 103 | x = final_embeds 104 | for i, blk in enumerate(self.transformer.blocks): 105 | x, _attn = blk(x, mask=co_masks) 106 | 107 | x = self.transformer.norm(x) 108 | 109 | cls_feats = self.pooler(x) 110 | 111 | ret = { 112 | "topo_name": batch.get("topo_name", None), 113 | "mc_name": batch.get("mc_name", None), 114 | "ol_name": batch.get("ol_name", None), 115 | 
"cls_feats": cls_feats, 116 | "mc": mc, 117 | "topo": topo, 118 | "ol": ol, 119 | "output": x, 120 | "output_mask": co_masks, 121 | } 122 | return ret 123 | 124 | def forward(self, batch): 125 | ret = dict() 126 | 127 | if len(self.current_tasks) == 0: 128 | ret.update(self.infer(batch)) 129 | 130 | # regression 131 | if "regression" in self.current_tasks: 132 | ret.update(objectives.compute_regression(self, batch, self.normalizer)) 133 | 134 | return ret 135 | 136 | def training_step(self, batch, batch_idx): 137 | module_utils.set_task(self) 138 | output = self(batch) 139 | total_loss = sum([v for k, v in output.items() if "loss" in k]) 140 | 141 | return total_loss 142 | 143 | def training_epoch_end(self, outputs): 144 | module_utils.epoch_wrapup(self) 145 | 146 | def validation_step(self, batch, batch_idx): 147 | module_utils.set_task(self) 148 | output = self(batch) 149 | 150 | def validation_epoch_end(self, outputs): 151 | module_utils.epoch_wrapup(self) 152 | 153 | def test_step(self, batch, batch_idx): 154 | module_utils.set_task(self) 155 | output = self(batch) 156 | output = { 157 | k: (v.cpu() if torch.is_tensor(v) else v) for k, v in output.items() 158 | } # update cpu for memory 159 | return output 160 | 161 | def test_epoch_end(self, outputs): 162 | module_utils.epoch_wrapup(self) 163 | 164 | # calculate r2 score when regression 165 | if "regression_logits" in outputs[0].keys(): 166 | logits = [] 167 | labels = [] 168 | for out in outputs: 169 | logits += out["regression_logits"].tolist() 170 | labels += out["regression_labels"].tolist() 171 | r2 = r2_score(torch.FloatTensor(logits), torch.FloatTensor(labels)) 172 | self.log(f"test/r2_score", r2) 173 | 174 | """ 175 | # example for saving cls_feats 176 | import json 177 | d = {} 178 | for out in outputs: 179 | mc_name = out["mc_name"] 180 | ol_name = out["ol_name"] 181 | topo_name = out["topo_name"] 182 | cls_feats = out["cls_feats"] 183 | for mc, ol, topos, cls_feat in zip(mc_name, ol_name, topo_name, 
cls_feats): 184 | cif_id = "_".join([str(topos), str(mc), str(ol)]) 185 | d.update({cif_id : cls_feat.tolist()}) 186 | json.dump(d, open("cls_feats.json", "w")) 187 | """ 188 | 189 | def configure_optimizers(self): 190 | return module_utils.set_schedule(self) 191 | -------------------------------------------------------------------------------- /mofreinforce/predictor/objectives.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchmetrics.functional import r2_score 4 | 5 | 6 | def weighted_mse_loss(logits, target, weight): 7 | return (weight * (logits - target) ** 2).mean() 8 | 9 | 10 | def compute_regression(pl_module, batch, normalizer): 11 | infer = pl_module.infer(batch) 12 | batch_size = pl_module.hparams.config["batch_size"] 13 | 14 | logits = pl_module.regression_head(infer["cls_feats"]).squeeze(-1) # [[B]] 15 | target = batch["target"] # [B] 16 | 17 | # normalize encode if config["mean"] and config["std], else pass 18 | target = normalizer.encode(target) 19 | 20 | loss = F.mse_loss(logits, target) 21 | 22 | ret = { 23 | "regression_loss": loss, 24 | "regression_logits": normalizer.decode(logits), 25 | "regression_labels": normalizer.decode(target), 26 | } 27 | ret.update(infer) 28 | 29 | # call update() loss and acc 30 | phase = "train" if pl_module.training else "val" 31 | loss = getattr(pl_module, f"{phase}_regression_loss")(ret["regression_loss"]) 32 | mae = getattr(pl_module, f"{phase}_regression_mae")( 33 | F.l1_loss(ret["regression_logits"], ret["regression_labels"]) 34 | ) 35 | 36 | r2 = getattr(pl_module, f"{phase}_regression_r2")(r2_score(logits, target)) 37 | 38 | pl_module.log( 39 | f"regression/{phase}/loss", loss, batch_size=batch_size, prog_bar=True 40 | ) 41 | pl_module.log(f"regression/{phase}/mae", mae, batch_size=batch_size, prog_bar=True) 42 | pl_module.log(f"regression/{phase}/r2", r2, batch_size=batch_size, prog_bar=True) 43 | return ret 
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        # Unspecified (falsy) widths fall back to the input width.
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention with an optional key-padding mask."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or per_head**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None):
        """
        :param x: input of shape [B, N, C]
        :param mask: optional [B, N] mask; zero / False entries are not attended to
        :return: (output [B, N, C], attention weights [B, num_heads, N, N])
        """
        B, N, C = x.shape
        assert C % self.num_heads == 0
        per_head = C // self.num_heads

        # [B, N, 3*C] -> [3, B, num_heads, N, per_head] -> three [B, num_heads, N, per_head]
        packed = self.qkv(x).reshape(B, N, 3, self.num_heads, per_head)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # [B, num_heads, N, N]
        if mask is not None:
            keep = mask.bool()[:, None, None, :]
            scores = scores.masked_fill(~keep, float("-inf"))
        attn = self.attn_drop(scores.softmax(dim=-1))  # [B, num_heads, N, N]

        # [B, num_heads, N, per_head] -> [B, N, C]
        context = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(context)), attn


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection."""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, mask=None):
        attended, attn = self.attn(self.norm1(x), mask=mask)
        x = x + self.drop_path(attended)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x, attn


class Transformer(nn.Module):
    """Container for a stack of ``Block`` layers, a final norm and an input dropout.

    No ``forward`` is defined; callers use ``blocks``, ``norm`` and ``pos_drop`` directly.
    """

    def __init__(
        self,
        embed_dim,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=None,
        add_norm_before_transformer=False,
        config=None,
    ):
        super().__init__()

        make_norm = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.add_norm_before_transformer = add_norm_before_transformer
        if add_norm_before_transformer:
            self.pre_norm = make_norm(embed_dim)

        # stochastic depth decay rule: drop-path rate grows linearly with depth
        drop_path_rates = torch.linspace(0, drop_path_rate, depth).tolist()
        layers = []
        for rate in drop_path_rates:
            layers.append(
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=rate,
                    norm_layer=make_norm,
                )
            )
        self.blocks = nn.ModuleList(layers)
        self.norm = make_norm(embed_dim)

        # trunc_normal_(self.cls_token, std=0.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, unit LayerNorm scale."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)
# Default configuration scope for the REINFORCE experiment.
@ex.config
def config():
    seed = 0
    exp_name = ""
    gpu_idx = 0

    # load predictor
    # predictor_load_path / mean / std / reward_max are parallel lists
    # (every named config below supplies one entry in each).
    predictor_load_path = ["model/predictor_qkh.ckpt"]
    mean = [None]
    std = [None]

    # get reward
    # NOTE(review): Reinforce.get_reward in reinforce/module.py gives reward 0
    # (not |pred| / |reward_max|) below the threshold when threshold is True,
    # and uses the signed ratio pred / reward_max when threshold is False --
    # the description below does not match that code; confirm which is intended.
    threshold = False
    """
    (1) threshold == True
    if |pred| >= |reward_max|, reward = 1, else : reward = |pred| / |reward_max|
    (2) threshold == False
    reward = |pred| / |reward_max|
    """
    reward_max = [1.0]  # if you want to normalize by max reward.

    # load generator
    dataset_dir = "data/dataset_generator"
    generator_load_path = "model/generator/generator.ckpt"

    # REINFORCE
    lr = 1e-4
    decay_ol = 1.0
    decay_topo = 1.0
    scheduler = "constant"  # warmup = get_linear_schedule_with_warmup, constant = get_constant_schedule
    early_stop = 0.5  # early_stop when the accuracy of scaffold is less than it.
    ratio_exploit = 0.0  # ratio for exploitation (freeze)
    """ to decrease effect of topology when calculating loss.
    Given topology is the discrete feature, its loss could much larger than the sum of organic linker.
    It gives rise to training failure.
    """
    ratio_mask_mc = 0.0  # ratio for masking mc of input_src

    # Trainer
    max_epochs = 20
    batch_size = 16
    accumulate_grad_batches = 2
    devices = 1
    num_nodes = 1
    precision = 16
    resume_from = None
    limit_train_batches = 500
    limit_val_batches = 30
    val_check_interval = 1.0
    test_only = False
    load_path = ""
    gradient_clip_val = None

    # tensorboard
    log_dir = "reinforce/logs"


@ex.named_config
def test():
    exp_name = "test"


"""
v0 Q_kH
"""


# Test-only run with the round-1 Q_kH predictor.
# NOTE(review): std is negative here and in v0_qkh_round1, while rounds 2 and 3
# use positive values -- looks like a possible sign typo; confirm against the
# value the round-1 predictor checkpoint was trained with before changing it.
@ex.named_config
def v0_scratch():
    exp_name = "v0_scratch"
    test_only = True

    # reward
    reward_max = [-60.0]

    # predictor
    predictor_load_path = ["model/predictor/best_predictor_v0_qkh_round1.ckpt"]
    mean = [-20.331]
    std = [-10.383]


# NOTE(review): negative std, see the note above v0_scratch.
@ex.named_config
def v0_qkh_round1():
    """
    omit mc in the input
    """
    exp_name = "v0_qkh_round1"
    max_epochs = 20

    # reward
    reward_max = [-60.0]

    # reinforce
    early_stop = 0.5  # early_stop when the accuracy of scaffold is less than it.
    ratio_exploit = 0.5  # ratio for exploitation
    ratio_mask_mc = 0.5  # ratio for masking mc of input_src

    # predictor
    predictor_load_path = ["model/predictor/best_predictor_v0_qkh_round1.ckpt"]
    mean = [-20.331]
    std = [-10.383]


@ex.named_config
def v0_qkh_round2():
    """
    omit mc in the input
    """
    exp_name = "v0_qkh_round2"
    max_epochs = 20

    # reward
    reward_max = [-60.0]

    # reinforce
    early_stop = 0.5  # early_stop when the accuracy of scaffold is less than it.
    ratio_exploit = 0.5  # ratio for exploitation
    ratio_mask_mc = 0.5  # ratio for masking mc of input_src

    # predictor
    predictor_load_path = ["model/predictor/best_predictor_v0_qkh_round2.ckpt"]
    mean = [-21.068]
    std = [10.950]


@ex.named_config
def v0_qkh_round3():
    """
    omit mc in the input
    """
    exp_name = "v0_qkh_round3"
    max_epochs = 20

    # reward
    reward_max = [-60.0]

    # reinforce
    early_stop = 0.5  # early_stop when the accuracy of scaffold is less than it.
    ratio_exploit = 0.5  # ratio for exploitation
    ratio_mask_mc = 0.5  # ratio for masking mc of input_src

    # predictor
    predictor_load_path = ["model/predictor/best_predictor_v0_qkh_round3.ckpt"]
    mean = [-21.810]
    std = [11.452]


"""
Selectivity (v1)
"""


# Test-only run with the round-1 selectivity predictor.
@ex.named_config
def v1_scratch():
    exp_name = "v1_scratch"
    test_only = True

    # reward
    reward_max = [10.0]

    # predictor
    predictor_load_path = ["model/predictor/best_predictor_v1_selectivity_round1.ckpt"]
    mean = [1.871]
    std = [1.922]


@ex.named_config
def v1_selectivity_round1():
    exp_name = "v1_selectivity_round1"
    max_epochs = 20

    # reward
    reward_max = [10.0]

    # reinforce
    early_stop = 0.5  # early_stop when the accuracy of scaffold is less than it.
    ratio_exploit = 0.5  # ratio for exploitation
    ratio_mask_mc = 0.5  # ratio for masking mc of input_src

    # predictor
    predictor_load_path = ["model/predictor/best_predictor_v1_selectivity_round1.ckpt"]
    mean = [1.871]
    std = [1.922]


@ex.named_config
def v1_selectivity_round2():
    exp_name = "v1_selectivity_round2"
    max_epochs = 20

    # reward
    reward_max = [10.0]

    # reinforce
    early_stop = 0.5  # early_stop when the accuracy of scaffold is less than it.
    ratio_exploit = 0.5  # ratio for exploitation
    ratio_mask_mc = 0.5  # ratio for masking mc of input_src

    # predictor
    predictor_load_path = ["model/predictor/best_predictor_v1_selectivity_round2.ckpt"]
    mean = [2.085]
    std = [2.052]


@ex.named_config
def v1_selectivity_round3():
    exp_name = "v1_selectivity_round3"
    max_epochs = 20

    # reward
    reward_max = [10.0]

    # reinforce
    early_stop = 0.5  # early_stop when the accuracy of scaffold is less than it.
    ratio_exploit = 0.5  # ratio for exploitation
    ratio_mask_mc = 0.5  # ratio for masking mc of input_src

    # predictor
    predictor_load_path = ["model/predictor/best_predictor_v1_selectivity_round3.ckpt"]
    mean = [2.258]
    std = [2.128]
class Reinforce(LightningModule):
    """REINFORCE fine-tuning of a MOF generator ("agent") against property predictors.

    A trainable copy of the generator (encoder frozen) samples a topology, a metal
    cluster and an organic linker token by token; the predictors turn the sample
    into a reward that weights the log-probabilities of the sampled tokens.
    """

    def __init__(self, agent, predictors, config):
        """
        :param agent: pretrained generator; deep-copied twice (trainable + frozen copy)
        :param predictors: sequence of predictor modules, one per reward term
        :param config: dict; keys read here are threshold, reward_max, lr, decay_ol,
            scheduler, decay_topo, early_stop, ratio_exploit, ratio_mask_mc, load_path
        """
        super(Reinforce, self).__init__()

        self.agent = copy.deepcopy(agent)
        # freeze only encoder
        for param in self.agent.transformer.encoder.parameters():
            param.requires_grad = False

        self.freeze_agent = copy.deepcopy(agent)
        # freeze agent
        for param in self.freeze_agent.parameters():
            param.requires_grad = False

        # Kept as a plain attribute (not registered as sub-modules).
        self.predictors = predictors

        # reward
        self.threshold = config["threshold"]
        self.reward_max = config["reward_max"]

        # reinforce
        self.lr = config["lr"]
        self.decay_ol = config["decay_ol"]
        self.scheduler = config["scheduler"]
        self.decay_topo = config["decay_topo"]
        self.early_stop = config["early_stop"]
        self.ratio_exploit = config["ratio_exploit"]
        self.ratio_mask_mc = config["ratio_mask_mc"]

        # load model
        if config["load_path"] != "":
            ckpt = torch.load(config["load_path"], map_location="cpu")
            state_dict = ckpt["state_dict"]
            self.load_state_dict(state_dict, strict=False)
            print(f"load model : {config['load_path']}")

    def forward(self, src, max_len=128):
        """
        Sample one MOF (topology, metal cluster, organic linker) from the agent.

        based on evaluate() in Generator
        :param src: encoded_input of generator [B=1, seq_len]; position 0 may be
            overwritten in place with the pad index (mc masking)
        :param max_len: maximum number of sampled tokens
        :return: dict with topo, mc, topo_idx, mc_idx, ol_idx, gen_sf, gen_sm and
            list_prob (probability of every sampled token); gen_sf / gen_sm are
            None when the sample fails validation
        """

        vocab_to_idx = self.agent.vocab_to_idx

        # mask_mc of encoded_src with probability `ratio_mask_mc`.
        # torch.rand is uniform on [0, 1), so the ratio is a real probability;
        # a standard-normal draw (torch.randn) would mask ~50% of the time even
        # with ratio_mask_mc == 0.
        if torch.rand(1).item() < self.ratio_mask_mc:
            src[:, 0].fill_(0)  # 0 is pad_idx

        src_mask = self.agent.transformer.make_src_mask(src)

        enc_src = self.agent.transformer.encoder(src, src_mask)  # [B=1, seq_len, hid_dim]

        # get target
        trg_indexes = [vocab_to_idx["[SOS]"]]

        list_prob = []

        for i in range(max_len):

            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(src.device)
            trg_mask = self.agent.transformer.make_trg_mask(trg_tensor)  # [B=1, 1, seq_len, seq_len]

            output = self.agent.transformer.decoder(
                trg_tensor, enc_src, trg_mask, src_mask)  # [B=1, seq_len, vocab_dim]

            with torch.no_grad():
                freeze_output = self.freeze_agent.transformer.decoder(
                    trg_tensor, enc_src, trg_mask, src_mask)  # [B=1, seq_len, vocab_dim]
            # exploitation: with probability `ratio_exploit` sample from the frozen agent
            if torch.rand(1)[0] < self.ratio_exploit:
                output = freeze_output

            # The decoder output dict tells which part of the MOF is being generated.
            if "output_ol" in output:
                prob = output["output_ol"].softmax(dim=-1)[0, -1]
            elif "output_mc" in output:
                prob = output["output_mc"].softmax(dim=-1)[0]
            else:
                prob = output["output_topo"].softmax(dim=-1)[0]
            m = torch.multinomial(prob, 1).item()

            list_prob.append(prob[m])
            trg_indexes.append(m)

            # the first tokens are topology and metal cluster, so EOS only counts afterwards
            if i >= 2 and m == vocab_to_idx["[EOS]"]:
                break

        # get topo
        topo_idx = trg_indexes[1]
        topo = self.agent.idx_to_topo[topo_idx]
        # get mc
        mc_idx = trg_indexes[2]
        mc = self.agent.idx_to_mc[mc_idx]
        # get ol
        ol_idx = trg_indexes[3:]
        ol_tokens = [self.agent.idx_to_vocab[idx] for idx in ol_idx]

        # check matching connection points
        # Copy the entry so the shared module-level lookup table is never mutated.
        topo_cn = list(topo_to_cn.get(topo, [0]))  # if topo is [PAD] then topo_cn = [0]
        if len(topo_cn) == 1:
            topo_cn.append(2)
        mc_cn = mc_to_cn.get(mc, -1)  # if mc is [PAD] then mc_cn = -1
        ol_cn = ol_idx.count(self.agent.vocab_to_idx["[*]"])

        # convert to selfies and smiles
        gen_sf = None
        gen_sm = None
        try:
            candidate_sf = "".join(ol_tokens[:-1])  # remove EOS token
            candidate_sm = sf.decoder(candidate_sf)
            roundtrip_sf = sf.encoder(candidate_sm)  # to check valid SELFIES
            # Explicit raises instead of `assert`, which is stripped under `python -O`.
            if ol_tokens[-1] != "[EOS]":
                raise ValueError("The last token is not [EOS]")
            if roundtrip_sf != candidate_sf:
                raise ValueError("SELFIES error")
            if set(topo_cn) != {mc_cn, ol_cn}:
                raise ValueError("connection points error")
        except Exception as e:
            # Deliberately best-effort: any failure marks the sample as invalid.
            print(e)
        else:
            gen_sf = candidate_sf
            gen_sm = candidate_sm

        ret = {
            "topo": topo,
            "mc": mc,
            "topo_idx": topo_idx,
            "mc_idx": mc_idx,
            "ol_idx": ol_idx,
            "gen_sf": gen_sf,
            "gen_sm": gen_sm,
            "list_prob": list_prob,
        }
        return ret

    def configure_optimizers(self):
        """Adam plus the schedule named by `self.scheduler` ("warmup" or "constant").

        :raises ValueError: for an unknown scheduler name
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        max_steps = self.trainer.estimated_stepping_batches
        print(f"mas_steps : {max_steps}")

        if self.scheduler == "warmup":
            scheduler = get_linear_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=int(self.trainer.max_epochs * 0.1),
                num_training_steps=self.trainer.max_epochs,
            )
        elif self.scheduler == "constant":
            scheduler = get_constant_schedule(
                optimizer=optimizer,
            )
        else:
            raise ValueError(
                f"unknown scheduler {self.scheduler!r}: expected 'warmup' or 'constant'"
            )
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Sample one MOF per input and accumulate the REINFORCE loss.

        Returns None (Lightning then skips the batch) when every sample of the
        batch failed validation, because there is no loss tensor to optimise.
        """
        self.agent.train()

        rl_loss = 0
        total_reward = 0
        num_fail = 0

        for enc_input in batch["encoded_input"]:
            src = enc_input.unsqueeze(0)
            out = self(src)
            if out["gen_sm"] is None:
                num_fail += 1
                continue

            # get reward from predictor
            with torch.no_grad():
                rewards, preds = self.get_reward(output_generator=out)
                reward = sum(rewards)

            total_reward += reward

            # Reinforce algorithm
            discounted_reward = reward
            tmp_loss = []  # per-token loss terms, only used for the debug print

            # topology, metal cluster (the probability of metal cluster is almost zero, because it is used as input)
            for prob in out["list_prob"][:2]:
                rl_loss -= torch.log(prob) * reward * self.decay_topo
                tmp_loss.append(-torch.log(prob).item() * reward)
            # organic linker (discounter_reward is applied only for organic linkers)
            # Tokens are visited from last to first, so the last token gets the full reward.
            for prob in out["list_prob"][2:][::-1]:
                rl_loss -= torch.log(prob) * discounted_reward
                # record the term with the same discount that entered the loss
                tmp_loss.append(-torch.log(prob).item() * discounted_reward)
                discounted_reward = discounted_reward * self.decay_ol

            print(rewards, preds, out["topo"], out["mc"], out["gen_sm"], out["gen_sf"], src[0, 1].item())
            print(rl_loss, tmp_loss[:2], sum(tmp_loss[2:]))

        # Averages are taken over the whole batch, failed samples included.
        n_batch = batch["encoded_input"].shape[0]
        rl_loss = rl_loss / n_batch
        total_reward = total_reward / n_batch
        num_fail = torch.tensor(num_fail) / n_batch

        self.log("train/rl_loss", rl_loss)
        self.log("train/reward", total_reward, prog_bar=True)
        self.log("train/num_fail", num_fail, prog_bar=True)

        if not torch.is_tensor(rl_loss):
            # No valid sample in this batch: a plain number cannot be back-propagated.
            return None
        return rl_loss

    def validation_step(self, batch, batch_idx):
        # Batches are only collected here; generation happens at epoch end.
        return batch

    def validation_epoch_end(self, batches):
        """Generate from all validation inputs, log metrics, stop on low scaffold accuracy."""
        list_src = torch.concat([b["encoded_input"] for b in batches], dim=0)
        metrics = Metrics(self.agent.vocab_to_idx, self.agent.idx_to_vocab)
        split = "val"

        metrics = self.update_metrics(list_src, metrics, split)
        if metrics.get_mean(metrics.scaffold) < self.early_stop:
            raise Exception(f"EarlyStopping for scaffold : scaffold accuracy is less than {self.early_stop}")

    def test_step(self, batch, batch_idx):
        return batch

    def test_epoch_end(self, batches):
        """Generate from all test inputs, log metrics and dump the results as JSON."""
        list_src = torch.concat([b["encoded_input"] for b in batches], dim=0)
        metrics = Metrics(self.agent.vocab_to_idx, self.agent.idx_to_vocab)
        split = "test"

        metrics = self.update_metrics(list_src, metrics, split)
        # save results
        ret = {
            "rewards": metrics.rewards,
            "preds": metrics.preds,
            "gen_sms": metrics.gen_ol,
            "gen_mcs": metrics.gen_mc,
            "gen_topos": metrics.gen_topo,
        }
        path = os.path.join(self.logger.save_dir, f"results_{self.logger.name}.json")
        # context manager so the file is flushed and closed deterministically
        with open(path, "w") as f:
            json.dump(ret, f)

    def get_reward(self,
                   output_generator,
                   sos_idx=1,
                   pad_idx=0,
                   max_len=128,
                   ):
        """Score one generated MOF with every predictor.

        :param output_generator: dict returned by forward()
        :param sos_idx: index of the start token prepended to the linker
        :param pad_idx: index used to pad the linker
        :param max_len: the linker is truncated / padded to max_len - 4 tokens
            before the start token is prepended
        :return: (rewards, preds), one entry per predictor. With `threshold` the
            rewards are all 1 only if every |pred| >= |reward_max|, otherwise all 0;
            without it each reward is pred / reward_max.
        """
        # create the input of predictor
        device = self.predictors[0].device
        ol_idx = output_generator["ol_idx"]
        ol_len = max_len - 4
        padded_ol = [sos_idx] + ol_idx[:ol_len] + [pad_idx] * (ol_len - len(ol_idx))

        batch = {
            "mc": torch.LongTensor([output_generator["mc_idx"]]).to(device),
            "topo": torch.LongTensor([output_generator["topo_idx"]]).to(device),
            "ol": torch.LongTensor([padded_ol]).to(device),
        }

        preds = []
        rewards = []
        for i, predictor in enumerate(self.predictors):
            infer = predictor.infer(batch)
            p = predictor.regression_head(infer["cls_feats"])
            p = predictor.normalizer.decode(p)
            p = p.detach().item()
            preds.append(p)

            # get reward
            if self.threshold:
                rewards.append(1 if abs(p) >= abs(self.reward_max[i]) else 0)
            else:
                rewards.append(p / self.reward_max[i])

        if self.threshold:
            # all-or-nothing: every predictor has to reach its threshold
            passed = 1 if all(rewards) else 0
            rewards = [passed] * len(rewards)
        return rewards, preds

    def update_metrics(self, list_src, metrics, split, num_log_mols=16):
        """
        Generate one MOF per source, update `metrics` and log scalars and images.

        :param list_src: list of source
        :param metrics: dictionary of metrics
        :param split: str, split
        :param num_log_mols: number of molecules for log, which will be saved in tensorboard
            (fewer are logged when fewer valid molecules were generated)
        :return: the updated metrics object
        """
        for src in tqdm(list_src):
            out = self(src.unsqueeze(0))

            if out["gen_sm"] is None:
                metrics.num_fail.append(1)
                continue
            else:
                metrics.num_fail.append(0)

            with torch.no_grad():
                rewards, preds = self.get_reward(output_generator=out)

            metrics.update(out, src, rewards=rewards, preds=preds)

        # Guard the uniqueness ratios: every generation may have failed.
        n_ol = len(metrics.gen_ol)
        n_topo = len(metrics.gen_topo)
        unique_ol = len(set(metrics.gen_ol)) / n_ol if n_ol else 0.0
        unique_topo_mc = len(set(zip(metrics.gen_topo, metrics.gen_mc))) / n_topo if n_topo else 0.0

        self.log(f"{split}/conn_match", metrics.get_mean(metrics.conn_match))
        self.log(f"{split}/unique_ol", unique_ol)
        self.log(f"{split}/unique_topo_mc", unique_topo_mc)
        self.log(f"{split}/scaffold", metrics.get_mean(metrics.scaffold))

        self.log(f"{split}/total_reward", metrics.get_mean(metrics.rewards))
        for i, reward in enumerate(zip(*metrics.rewards)):
            self.log(f"{split}/reward_{i}", metrics.get_mean(reward))
        for i, pred in enumerate(zip(*metrics.preds)):
            self.log(f"{split}/target_{i}", metrics.get_mean(pred))
        self.log(f"{split}/num_fail", metrics.get_mean(metrics.num_fail))

        # add image to log: generated linker followed by its input fragments
        # NOTE(review): assumes metrics.input_frags has one entry per generated linker.
        for i in range(min(num_log_mols, n_ol)):
            ol = metrics.gen_ol[i]
            frags = metrics.input_frags[i]
            imgs = []
            for s in [ol] + frags:
                m = Chem.MolFromSmiles(s)
                if not m:
                    continue
                imgs.append(np.array(MolToImage(m)))
            if imgs:  # np.stack raises on an empty list
                imgs = np.stack(imgs, axis=0)
                self.logger.experiment.add_image(f"{split}/{i}", imgs, self.global_step, dataformats="NHWC")

        # total gen_ol
        imgs = []
        for smiles in metrics.gen_ol[:num_log_mols]:
            try:
                mol = Chem.MolFromSmiles(smiles)
                imgs.append(np.array(MolToImage(mol)))
            except Exception as e:
                # best-effort logging: a molecule that cannot be drawn is skipped
                print(e)
        if imgs:
            imgs = np.stack(imgs, axis=0)
            self.logger.experiment.add_image(f"{split}/gen_ol/", imgs, self.global_step, dataformats="NHWC")

        return metrics
(dictionary): dictionary for metal cluster with index 44 | topo_to_idx (dictionary): dictionary for topology with index 45 | emb_dim (int): dimension of embedding for metal cluster and topology 46 | hid_dim (int): dimension of hidden for metal culster and topology 47 | """ 48 | super(Reinforce, self).__init__() 49 | # save config 50 | if config is not None: 51 | json.dump(config, open(f"{self.logger.save_dir}/{self.logger.name}/hparams.json", "w")) 52 | 53 | self.generator = generator 54 | self.rmsd_predictor = rmsd_predictor 55 | self.target_predictor = target_predictor 56 | self.get_reward_rmsd = partial(get_reward_rmsd, 57 | criterion=config["criterion_rmsd"], 58 | reward_positive=config["reward_positive_rmsd"], 59 | reward_negative=config["reward_negative_rmsd"], 60 | ) 61 | self.get_reward_target = partial(get_reward_target, 62 | criterion=config["criterion_target"], 63 | reward_positive=config["reward_positive_target"], 64 | reward_negative=config["reward_negative_target"], 65 | ) 66 | 67 | self.vocab_to_idx = vocab_to_idx 68 | self.mc_to_idx = mc_to_idx 69 | self.topo_to_idx = topo_to_idx 70 | self.emb_dim = emb_dim 71 | self.hid_dim = hid_dim 72 | 73 | self.rmsd_predictor.eval() 74 | self.target_predictor.eval() 75 | self.upgrade_generator() 76 | 77 | self.log_dir = f"{config['log_dir']}/{config['exp_name']}" 78 | if os.path.exists(self.log_dir) and config["test_only"] is not True: 79 | raise Exception(f"{self.log_dir} is already exist") 80 | 81 | self.writer = SummaryWriter(log_dir=self.log_dir) 82 | 83 | def upgrade_generator(self): 84 | """ 85 | add metal clutser and topology to generator_v0 86 | """ 87 | self.generator.embedding_topo = nn.Embedding(len(self.topo_to_idx), self.emb_dim) 88 | self.generator.rc_mc_1 = nn.Linear(self.emb_dim, self.hid_dim) 89 | self.generator.rc_mc_2 = nn.Linear(self.hid_dim, len(self.mc_to_idx)) 90 | 91 | self.generator.optimizer = self.generator.optimizer_instance(self.generator.parameters(), 92 | lr=1e-3, 93 | 
weight_decay=0.00001) 94 | self.generator = self.generator.cuda() 95 | 96 | def load_model(self, load_filename): 97 | self.generator.load_state_dict(torch.load(load_filename)) 98 | 99 | def generate_mc(self, topo, return_loss=False): 100 | emb_topo = self.generator.embedding_topo(topo) 101 | logits = F.relu(self.generator.rc_mc_1(emb_topo)) 102 | pred_mc = F.softmax(self.generator.rc_mc_2(logits), dim=-1) 103 | m = Categorical(pred_mc) 104 | mc = m.sample() 105 | loss_mc = torch.log(pred_mc[mc]) 106 | 107 | if return_loss: 108 | return mc, loss_mc 109 | else: 110 | return mc 111 | 112 | def generate_ol(self): 113 | num_fail = 0 114 | # generate smiles 115 | valid_smiles = False 116 | while not valid_smiles: 117 | gen_sf = self.generator.evaluate() 118 | try: 119 | gen_sm = sf.decoder(gen_sf[5:-5]) 120 | m = Chem.MolFromSmiles(gen_sm) 121 | except Exception: 122 | num_fail += 1 123 | continue 124 | 125 | # encodig_sf 126 | encoded_sf = sf.selfies_to_encoding(selfies=gen_sf, 127 | vocab_stoi=self.vocab_to_idx, 128 | enc_type="label") 129 | 130 | if m and 10 < len(encoded_sf) <= 100: 131 | valid_smiles = True 132 | else: 133 | num_fail += 1 134 | 135 | return encoded_sf, gen_sm, num_fail 136 | 137 | def init_generator(self): 138 | """ 139 | Initialize stackRNN 140 | """ 141 | 142 | hidden = self.generator.init_hidden() 143 | if self.generator.has_cell: 144 | cell = self.generator.init_cell() 145 | hidden = (hidden, cell) 146 | if self.generator.has_stack: 147 | stack = self.generator.init_stack() 148 | else: 149 | stack = None 150 | 151 | return hidden, stack 152 | 153 | def policy_gradient(self, n_batch, gamma, topo_idx=-1): 154 | """ 155 | REINFORCE algorithm 156 | """ 157 | self.generator.train() 158 | 159 | rl_loss = 0 160 | self.generator.optimizer.zero_grad() 161 | total_reward = 0 162 | num_fail = 0 163 | 164 | for n_epi in range(n_batch): 165 | # topology 166 | if topo_idx < 0: 167 | topo = torch.randint(0, len(self.topo_to_idx), size=(1,))[0].cuda() 168 | 
else: 169 | topo = torch.tensor(topo_idx).cuda() 170 | 171 | # metal cluster 172 | mc, loss_mc = self.generate_mc(topo, return_loss=True) 173 | rl_loss -= loss_mc 174 | 175 | # organic linker 176 | encoded_sf, _, n_f = self.generate_ol() 177 | num_fail += n_f 178 | 179 | # rmsd reward 180 | reward_rmsd, output_rmsd = self.get_reward_rmsd(topo, mc, encoded_sf, self.rmsd_predictor) 181 | 182 | # target reward 183 | reward_target, output_target = self.get_reward_target(topo, mc, encoded_sf, self.target_predictor) 184 | 185 | # REINFORCE algorithm 186 | discounted_reward = reward_rmsd + reward_target 187 | total_reward += reward_rmsd + reward_target 188 | 189 | # accumulate trajectory 190 | hidden, stack = self.init_generator() 191 | 192 | # accumulate trajectory 193 | trajectory = torch.LongTensor(encoded_sf).cuda() 194 | 195 | for p in range(len(trajectory) - 1): 196 | 197 | output, hidden, stack = self.generator(trajectory[p], hidden, stack) 198 | log_probs = F.log_softmax(output, dim=-1) 199 | top_i = trajectory[p + 1] 200 | rl_loss -= (log_probs[0, top_i] * discounted_reward) 201 | discounted_reward = discounted_reward * gamma 202 | 203 | # backpropagation 204 | rl_loss = rl_loss / n_batch 205 | total_reward = total_reward / n_batch 206 | num_fail = num_fail / n_batch 207 | 208 | rl_loss.backward() 209 | self.generator.optimizer.step() 210 | 211 | return total_reward, rl_loss.item(), num_fail 212 | 213 | def test_estimate(self, n_to_generate, topo_idx=-1): 214 | with torch.no_grad(): 215 | self.generator.evaluate() 216 | gen_mofs = [] 217 | rewards = {"rmsd": [], "target": []} 218 | outputs = {"rmsd": [], "target": []} 219 | num_fail = 0 220 | 221 | for i in range(n_to_generate): 222 | # topology 223 | if topo_idx < 0: 224 | topo = torch.randint(0, len(self.topo_to_idx), size=(1,))[0].cuda() 225 | else: 226 | topo = torch.tensor(topo_idx).cuda() 227 | 228 | # metal cluster 229 | mc = self.generate_mc(topo) 230 | 231 | # organic linker 232 | encoded_sf, gen_sm, n_f 
= self.generate_ol() 233 | num_fail += n_f 234 | gen_mofs.append([topo.detach().item(), mc.detach().item(), gen_sm]) 235 | 236 | # rmsd reward 237 | reward_rmsd, output_rmsd = self.get_reward_rmsd(topo, mc, encoded_sf, self.rmsd_predictor) 238 | 239 | # target reward 240 | reward_target, output_target = self.get_reward_target(topo, mc, encoded_sf, self.target_predictor) 241 | 242 | rewards["rmsd"].append(reward_rmsd) 243 | rewards["target"].append(reward_target) 244 | 245 | outputs["rmsd"].append(output_rmsd) 246 | outputs["target"].append(output_target) 247 | 248 | return rewards, outputs, num_fail, gen_mofs 249 | 250 | 251 | def write_logs_test_and_estimate(self, test_reward, test_output, test_num_fail, gen_mofs, n_iter, n_to_generate): 252 | reward_rmsd = np.array(test_reward["rmsd"]) 253 | reward_target = np.array(test_reward["target"]) 254 | output_rmsd = np.array(test_output["rmsd"]) 255 | output_target = np.array(test_output["target"]) 256 | 257 | # add scalar to log 258 | self.writer.add_scalar("test/reward/rmsd", reward_rmsd.mean(), n_iter) 259 | self.writer.add_scalar("test/reward/target", reward_target.mean(), n_iter) 260 | self.writer.add_scalar("test/output/rmsd", output_rmsd.mean(), n_iter) 261 | self.writer.add_scalar("test/output/target", output_target.mean(), n_iter) 262 | self.writer.add_scalar("test/num_fail", test_num_fail / n_to_generate, n_iter) 263 | # add histogram to log 264 | fig, axes = plt.subplots(2, 2, figsize=(12, 8)) 265 | data = [reward_rmsd, reward_target, output_rmsd, output_target] 266 | title = ["reward_rmsd", "reward_target", "output_rmsd", "output_target"] 267 | for i in range(4): 268 | ax = axes[i // 2, i % 2] 269 | ax.set_title(title[i]) 270 | sns.histplot(data[i], ax=ax) 271 | self.writer.add_figure("test", fig, n_iter) 272 | 273 | # add image to log 274 | imgs = [] 275 | for i, m in enumerate(gen_mofs): 276 | m = Chem.MolFromSmiles(m[2]) 277 | img = MolToImage(m) 278 | img = np.array(img) 279 | img = torch.tensor(img) 280 
| imgs.append(img) 281 | imgs = np.stack(imgs, axis=0) 282 | self.writer.add_images("test/gen_ol/", imgs, n_iter, dataformats="NHWC") 283 | 284 | 285 | def train(self, 286 | n_iters=10000, 287 | n_print=100, 288 | n_to_generate=200, 289 | n_batch=10, 290 | gamma=0.80, 291 | topo_idx=-1, 292 | ): 293 | """ 294 | 295 | :param n_iters: number of iterations 296 | :param n_print: number of iterations to print 297 | :param n_to_generate: how many mof will be generated when test_and_estimate 298 | :param n_batch: number of batch size 299 | :param gamma: discount ratio of REINFORCE 300 | :param topo_idx: if topo_idx < 0, randomly selecting topology 301 | :return: None 302 | """ 303 | 304 | reward = 0 305 | loss = 0 306 | num_fail = 0 307 | callback = 0 308 | 309 | for n_iter in range(n_iters): 310 | 311 | batch_reward, batch_loss, batch_num_fail = self.policy_gradient(n_batch, gamma, topo_idx) 312 | reward += batch_reward 313 | loss += batch_loss 314 | num_fail += batch_num_fail 315 | 316 | if n_iter % n_print == 0: 317 | self.writer.add_scalar("epoch", n_iter, n_iter) 318 | if n_iter != 0: 319 | print(f"########## iteration {n_iter} ##########") 320 | print( 321 | f"train | reward : {reward / n_print:.3f} ," 322 | f" loss : {loss / n_print:.3f} , " 323 | f"num_fail : {num_fail / n_print:.3f}") 324 | self.writer.add_scalar("train/loss", loss / n_print, n_iter) 325 | self.writer.add_scalar("train/reward", reward / n_print, n_iter) 326 | self.writer.add_scalar("train/num_fail", num_fail / n_print, n_iter) 327 | reward = 0 328 | loss = 0 329 | num_fail = 0 330 | 331 | test_reward, test_output, test_num_fail, gen_mofs = self.test_estimate(n_to_generate, topo_idx) 332 | 333 | test_reward_rmsd = np.mean(test_reward['rmsd']) 334 | test_reward_target = np.mean(test_reward["target"]) 335 | total_test_reward = test_reward_rmsd + test_reward_target 336 | 337 | test_output_rmsd = np.mean(test_output['rmsd']) 338 | test_output_target = np.mean(test_output["target"]) 339 | 340 | print( 
341 | f"test | reward_rmsd : {test_reward_rmsd:.3f} , reward_target : {test_reward_target:.3f}, " 342 | f"output_rmsd : {test_output_rmsd:.3f}, output_target : {test_output_target:.3f}," 343 | f" num_fail : {test_num_fail}/{n_to_generate}") 344 | print(gen_mofs[:5]) 345 | 346 | self.write_logs_test_and_estimate(test_reward, 347 | test_output, 348 | test_num_fail, 349 | gen_mofs, 350 | n_iter, 351 | n_to_generate, 352 | ) 353 | 354 | if total_test_reward > callback: 355 | path_ckpt = os.path.join(self.log_dir, f"reinforce_model_{n_iter}.ckpt") 356 | torch.save(self.generator.state_dict(), path_ckpt) 357 | print("model save !!!") 358 | callback = total_test_reward 359 | -------------------------------------------------------------------------------- /mofreinforce/run_generator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.strategies.ddp import DDPStrategy 6 | from pytorch_lightning.loggers import TensorBoardLogger 7 | 8 | from generator.config_generator import ex 9 | from generator.datamodule import GeneratorDatamodule 10 | from generator.module import Generator 11 | 12 | 13 | @ex.automain 14 | def main(_config): 15 | _config = copy.deepcopy(_config) 16 | pl.seed_everything(_config["seed"]) 17 | 18 | dm = GeneratorDatamodule(_config) 19 | 20 | model = Generator(_config) 21 | 22 | exp_name = f"{_config['exp_name']}" 23 | 24 | os.makedirs(_config["log_dir"], exist_ok=True) 25 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 26 | save_top_k=1, 27 | verbose=True, 28 | monitor="val/the_metric", 29 | mode="max", 30 | save_last=True, 31 | ) 32 | 33 | logger = pl.loggers.TensorBoardLogger( 34 | _config["log_dir"], 35 | name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}', 36 | ) 37 | 38 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step") 39 | callbacks = [checkpoint_callback, 
@ex.automain
def main(_config):
    """
    Train and/or test the predictor described by the sacred config.

    The config keys read here are ``seed``, ``exp_name``, ``log_dir``,
    ``load_path``, ``devices``, ``batch_size``, ``per_gpu_batchsize``,
    ``num_nodes``, ``precision``, ``max_epochs``, ``resume_from``,
    ``val_check_interval`` and ``test_only``.
    """
    _config = copy.deepcopy(_config)
    pl.seed_everything(_config["seed"])

    dm = Datamodule(_config)

    model = Predictor(_config)

    exp_name = f"{_config['exp_name']}"

    os.makedirs(_config["log_dir"], exist_ok=True)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=1,
        verbose=True,
        monitor="val/the_metric",
        mode="max",
        save_last=True,
    )

    # the run name ends with the load_path file name minus its last 5 characters
    logger = pl.loggers.TensorBoardLogger(
        _config["log_dir"],
        name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}',
    )

    lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
    callbacks = [checkpoint_callback, lr_callback]

    # `devices` may be a count or a list of device ids; only the count is used
    num_gpus = _config["devices"]
    if isinstance(num_gpus, list):
        num_gpus = len(num_gpus)

    # gradient accumulation: never let the floor division produce 0 when the
    # requested batch size is smaller than one step across all devices
    accumulate_grad_batches = max(
        1,
        _config["batch_size"] // (
            _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"]
        ),
    )

    trainer = pl.Trainer(
        # `accelerator`/`devices` replace the deprecated `gpus` argument and
        # match run_generator.py and run_reinforce.py
        accelerator="gpu",
        devices=num_gpus,
        num_nodes=_config["num_nodes"],
        precision=_config["precision"],
        deterministic=True,
        strategy=DDPStrategy(find_unused_parameters=True),
        max_epochs=_config["max_epochs"],
        callbacks=callbacks,
        logger=logger,
        accumulate_grad_batches=accumulate_grad_batches,
        log_every_n_steps=10,
        resume_from_checkpoint=_config["resume_from"],
        val_check_interval=_config["val_check_interval"],
        # profiler="simple",
    )

    if not _config["test_only"]:
        trainer.fit(model, datamodule=dm)
    # testing runs in both modes
    trainer.test(model, datamodule=dm)
@ex.automain
def main(_config):
    """
    Build the predictors and the generator from their checkpoints, wrap them in
    the Reinforce module and run fitting (or testing when ``test_only`` is set).
    """
    pl.seed_everything(_config["seed"])

    # 1. load predictor: one frozen regression predictor per checkpoint path
    predictors = []
    for idx, ckpt_path in enumerate(_config["predictor_load_path"]):
        pred_config = predictor_config()
        pred_config["test_only"] = True
        pred_config["loss_names"] = _loss_names({"regression": 1})
        pred_config["load_path"] = ckpt_path
        pred_config["mean"] = _config["mean"][idx]
        pred_config["std"] = _config["std"][idx]

        frozen = Predictor(pred_config)
        frozen.eval()
        predictors.append(frozen)

    # 2. load generator
    gen_config = generator_config()
    for key, source_key in (
        ("load_path", "generator_load_path"),
        ("batch_size", "batch_size"),
        ("dataset_dir", "dataset_dir"),
    ):
        gen_config[key] = _config[source_key]

    generator = Generator(gen_config)
    dm = GeneratorDatamodule(gen_config)

    # 3. set reinforce
    _config = copy.deepcopy(_config)
    pl.seed_everything(_config["seed"])

    model = Reinforce(generator, predictors, _config)

    os.makedirs(_config["log_dir"], exist_ok=True)
    ckpt_cb = pl.callbacks.ModelCheckpoint(
        save_top_k=-1,
        verbose=True,
        monitor="val/total_reward",
        mode="max",
        save_last=True,
    )

    run_name = f'{_config["exp_name"]}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}'
    tb_logger = pl.loggers.TensorBoardLogger(_config["log_dir"], name=run_name)
    lr_cb = pl.callbacks.LearningRateMonitor(logging_interval="step")

    trainer_kwargs = dict(
        accelerator="gpu",
        devices=_config["devices"],
        num_nodes=_config["num_nodes"],
        precision=_config["precision"],
        deterministic=True,
        strategy=DDPStrategy(find_unused_parameters=True),
        max_epochs=_config["max_epochs"],
        callbacks=[ckpt_cb, lr_cb],
        logger=tb_logger,
        accumulate_grad_batches=_config["accumulate_grad_batches"],
        log_every_n_steps=10,
        resume_from_checkpoint=_config["resume_from"],
        limit_train_batches=_config["limit_train_batches"],
        limit_val_batches=_config["limit_val_batches"],
        # the sanity check uses the same setting as the validation limit
        num_sanity_val_steps=_config["limit_val_batches"],
        val_check_interval=_config["val_check_interval"],
        gradient_clip_val=_config["gradient_clip_val"],
    )
    trainer = pl.Trainer(**trainer_kwargs)

    if _config["test_only"]:
        trainer.test(model, datamodule=dm)
    else:
        trainer.fit(model, datamodule=dm)
        # trainer.test(model, datamodule=dm)
-------------------------------------------------------------------------------- /mofreinforce/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wget 3 | import tarfile 4 | from pathlib import Path 5 | from mofreinforce import __root_dir__ 6 | 7 | 8 | DEFAULT_PATH = Path(__root_dir__) 9 | 10 | 11 | class DownloadError(Exception): 12 | pass 13 | 14 | 15 | def _remove_tmp_file(direc: Path): 16 | tmp_list = direc.parent.glob("*.tmp") 17 | for tmp in tmp_list: 18 | if tmp.exists(): 19 | os.remove(tmp) 20 | 21 | 22 | def _download_file(link, direc, name="target"): 23 | if direc.exists(): 24 | print(f"{name} already exists.") 25 | return 26 | try: 27 | print(f"\n====Download {name} =============================================\n") 28 | filename = wget.download(link, out=str(direc)) 29 | except KeyboardInterrupt: 30 | _remove_tmp_file(direc) 31 | raise 32 | except Exception as e: 33 | _remove_tmp_file(direc) 34 | raise DownloadError(e) 35 | else: 36 | print( 37 | f"\n====Successfully download : {filename}=======================================\n" 38 | ) 39 | 40 | 41 | def download_default(direc=None, remove_tarfile=False): 42 | """ 43 | downlaod data and pre-trained models including a generator, predictors for DAC 44 | """ 45 | if not direc: 46 | direc = Path(DEFAULT_PATH) 47 | if not direc.exists(): 48 | direc.mkdir(parents=True, exist_ok=True) 49 | direc = direc / "default.tar.gz" 50 | else: 51 | direc = Path(direc) 52 | if direc.is_dir(): 53 | if not direc.exists(): 54 | direc.mkdir(parents=True, exist_ok=True) 55 | direc = direc / "default.tar.gz" 56 | else: 57 | raise ValueError(f"direc must be path for directory, not {direc}") 58 | 59 | link = "https://figshare.com/ndownloader/files/40215109" 60 | name = "basic data and pretrained models" 61 | _download_file(link, direc, name) 62 | 63 | print(f"\n====Unzip : {name}===============================================\n") 64 | with 
tarfile.open(direc) as f: 65 | f.extractall(path=direc.parent) 66 | 67 | print( 68 | f"\n====Unzip successfully: {name}===============================================\n" 69 | ) 70 | 71 | if remove_tarfile: 72 | os.remove(direc) 73 | -------------------------------------------------------------------------------- /mofreinforce/utils/gadgets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric 3 | 4 | 5 | class Accuracy(Metric): 6 | def __init__(self, dist_sync_on_step=False): 7 | super().__init__(dist_sync_on_step=dist_sync_on_step) 8 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum") 9 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 10 | 11 | def update(self, logits, target): 12 | logits, target = ( 13 | logits.detach().to(self.correct.device), 14 | target.detach().to(self.correct.device), 15 | ) 16 | if len(logits.shape) > 1: 17 | preds = logits.argmax(dim=-1) 18 | else: 19 | # binary accuracy 20 | logits[logits >= 0.5] = 1 21 | logits[logits < 0.5] = 0 22 | preds = logits 23 | 24 | preds = preds[target != 0] # pad_idx 25 | target = target[target != 0] # pad_idx 26 | 27 | if target.numel() == 0: 28 | return 1 29 | 30 | assert preds.shape == target.shape 31 | 32 | self.correct += torch.sum(preds == target) 33 | self.total += target.numel() 34 | 35 | def compute(self): 36 | return self.correct / self.total 37 | 38 | 39 | class Scalar(Metric): 40 | def __init__(self, dist_sync_on_step=False): 41 | super().__init__(dist_sync_on_step=dist_sync_on_step) 42 | self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum") 43 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 44 | 45 | def update(self, scalar): 46 | if isinstance(scalar, torch.Tensor): 47 | scalar = scalar.detach().to(self.scalar.device) 48 | else: 49 | scalar = torch.tensor(scalar).float().to(self.scalar.device) 50 | self.scalar += 
class Metrics():
    """
    Collects per-sample results of generated MOFs: connection matching,
    scaffold preservation, the generated building blocks, and optionally the
    rewards and predictions.
    """

    def __init__(self, vocab_to_idx, idx_to_vocab):
        """
        :param vocab_to_idx: dictionary from selfies token to index
        :param idx_to_vocab: dictionary from index to selfies token
        """
        # generator
        self.num_fail = []
        self.conn_match = []
        self.scaffold = []
        # collect
        self.input_frags = []
        self.gen_ol = []
        self.gen_topo = []
        self.gen_mc = []
        # reinforce
        self.rewards = []
        self.preds = []
        # vocab
        self.topo_to_cn = topo_to_cn
        self.mc_to_cn = mc_to_cn
        self.vocab_to_idx = vocab_to_idx
        self.idx_to_vocab = idx_to_vocab

    def update(self, output, src, rewards=None, preds=None):
        """
        (1) metric for connection matching
        (2) metric for unique organic linker
        (3) metric for unique topology and metal cluster
        (4) metric for scaffold
        (5) (optional) rewards
        (6) (optional) preds

        :param output: dictionary with the keys "topo", "mc", "ol_idx" and "gen_sm"
        :param src: tensor of token indices; the tokens after the first three
                    positions and before "[EOS]" are the input fragments
        :param rewards: stored as given
        :param preds: stored as given
        :return: None
        """
        # copy the entry: appending to it in place would modify the shared table
        topo_cn = list(self.topo_to_cn.get(output["topo"], [0]))  # if topo is [PAD] then topo_cn = [0]
        if len(topo_cn) == 1:
            topo_cn.append(2)
        mc_cn = self.mc_to_cn.get(output["mc"], -1)  # if mc is [PAD] then mc_cn = -1
        ol_cn = output["ol_idx"].count(self.vocab_to_idx["[*]"])

        # (1) metric for connection matching
        if set(topo_cn) == {mc_cn, ol_cn}:
            self.conn_match.append(1)
        else:
            self.conn_match.append(0)

        # (2) metric for unique organic linker
        # (3) metric for unique topology and metal cluster
        gen_sm = output["gen_sm"]
        self.gen_ol.append(gen_sm)
        self.gen_topo.append(output["topo"])
        self.gen_mc.append(output["mc"])

        # (4) metric for scaffold
        # get frags
        frags = [self.idx_to_vocab[idx.item()] for idx in src.squeeze(0)[3:]]
        frags = frags[:frags.index("[EOS]")]
        frags = "".join(frags)

        frags_sm = [sf.decoder(f) for f in frags.split(".")]
        self.input_frags.append(frags_sm)
        # replace * with H
        du, hy = Chem.MolFromSmiles('*'), Chem.MolFromSmiles('[H]')
        try:
            m = Chem.MolFromSmiles(gen_sm)
            check_gen_sm = Chem.ReplaceSubstructs(m, du, hy, replaceAll=True)[0]

            check_ = True
            for sm in frags_sm:
                if not check_gen_sm.HasSubstructMatch(Chem.MolFromSmiles(sm)):
                    check_ = False
            if check_:
                self.scaffold.append(1)
            else:
                self.scaffold.append(0)
        except Exception:
            # best effort: a sample that cannot be checked adds no scaffold
            # entry; KeyboardInterrupt and SystemExit are no longer swallowed
            pass

        # (5) metric for reward
        self.rewards.append(rewards)
        self.preds.append(preds)

    @staticmethod
    def get_mean(list_):
        """
        :param list_: sequence of numbers
        :return: mean of list_ as a torch tensor
        """
        return torch.mean(torch.Tensor(list_))
def epoch_wrapup(pl_module):
    """
    Log the epoch-level value of every active metric, reset the metric, and log
    the sum of the per-task scores as ``{phase}/the_metric``.

    The per-task score is the negative mae for "regression"/"vfr", the sum of
    the three accuracies for "generator", and the accuracy for any other task.
    Tasks whose value in ``loss_names`` is below 1 are skipped.
    """
    stage = "train" if pl_module.training else "val"

    def _log_and_reset(task, suffix, log_key):
        # compute -> log -> reset for one metric attribute of pl_module
        metric = getattr(pl_module, f"{stage}_{task}_{suffix}")
        result = metric.compute()
        pl_module.log(log_key, result)
        metric.reset()
        return result

    total = 0
    for task, weight in pl_module.hparams.config["loss_names"].items():
        if weight < 1:
            continue

        if task in ("regression", "vfr"):
            # mse loss
            _log_and_reset(task, "loss", f"{task}/{stage}/loss_epoch")
            # mae loss (a lower mae gives a higher score)
            score = -_log_and_reset(task, "mae", f"{task}/{stage}/mae_epoch")
        elif task == "generator":
            topo_acc = _log_and_reset(task, "acc_topo", f"{task}/{stage}/acc_topo")
            mc_acc = _log_and_reset(task, "acc_mc", f"{task}/{stage}/acc_mc")
            ol_acc = _log_and_reset(task, "acc_ol", f"{task}/{stage}/acc_ol")
            score = topo_acc + mc_acc + ol_acc
        else:
            score = _log_and_reset(task, "accuracy", f"{task}/{stage}/accuracy_epoch")
            _log_and_reset(task, "loss", f"{task}/{stage}/loss_epoch")

        total += score

    pl_module.log(f"{stage}/the_metric", total)
], 142 | "weight_decay": 0.0, 143 | "lr": lr, 144 | }, 145 | { 146 | "params": [ 147 | p 148 | for n, p in pl_module.named_parameters() 149 | if not any(nd in n for nd in no_decay) # not within no_decay 150 | and any(bb in n for bb in head_names) # within head_names 151 | ], 152 | "weight_decay": wd, 153 | "lr": lr * lr_mult, 154 | }, 155 | { 156 | "params": [ 157 | p 158 | for n, p in pl_module.named_parameters() 159 | if any(nd in n for nd in no_decay) and any(bb in n for bb in head_names) 160 | # within no_decay and head_names 161 | ], 162 | "weight_decay": 0.0, 163 | "lr": lr * lr_mult, 164 | }, 165 | ] 166 | 167 | if optim_type == "adamw": 168 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98) 169 | ) 170 | elif optim_type == "adam": 171 | optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr) 172 | elif optim_type == "sgd": 173 | optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=lr, momentum=0.9) 174 | 175 | if pl_module.trainer.max_steps == -1: 176 | max_steps = ( 177 | len(pl_module.trainer.datamodule.train_dataloader()) 178 | * pl_module.trainer.max_epochs 179 | // pl_module.trainer.accumulate_grad_batches 180 | // (pl_module.trainer.num_devices * pl_module.trainer.num_nodes) 181 | ) 182 | else: 183 | max_steps = pl_module.trainer.max_steps 184 | 185 | warmup_steps = pl_module.hparams.config["warmup_steps"] 186 | if isinstance(pl_module.hparams.config["warmup_steps"], float): 187 | warmup_steps = int(max_steps * warmup_steps) 188 | 189 | print(f"max_epochs: {pl_module.trainer.max_epochs} | max_steps: {max_steps} | warmup_steps : {warmup_steps} " 190 | f"| weight_decay : {wd} | decay_power : {decay_power}") 191 | 192 | if decay_power == "cosine": 193 | scheduler = get_cosine_schedule_with_warmup( 194 | optimizer, 195 | num_warmup_steps=warmup_steps, 196 | num_training_steps=max_steps, 197 | ) 198 | elif decay_power == 'constant': 199 | scheduler = get_constant_schedule( 200 | optimizer, 201 | ) 202 | 
class Normalizer(object):
    """
    normalize for regression

    ``encode`` maps a value to ``(value - mean) / std`` and ``decode`` maps it
    back. Both are the identity when ``mean`` is None or when ``std`` is None
    or zero.
    """

    def __init__(self, mean, std):
        """
        :param mean: offset removed by ``encode`` (0 is a valid value), or None
        :param std: scale divided out by ``encode``, or None
        """
        # `mean` is compared with None instead of tested for truth, so that a
        # mean of 0 does not switch the normalization off
        if mean is not None and std:
            self._norm_func = lambda tensor: (tensor - mean) / std
            self._denorm_func = lambda tensor: tensor * std + mean
        else:
            self._norm_func = lambda tensor: tensor
            self._denorm_func = lambda tensor: tensor

        self.mean = mean
        self.std = std

    def encode(self, tensor):
        """Return the normalized value of ``tensor``."""
        return self._norm_func(tensor)

    def decode(self, tensor):
        """Return the de-normalized value of ``tensor``."""
        return self._denorm_func(tensor)
18 | - `mc_name` : (string) name of metal cluster. 19 | - `ol_name` : (string) name of organic linker. 20 | (required) 21 | The topologies, metal clusters and organic linkers should be vectorized. 22 | Because topologies and metal clusters are categorical variables, they need to be converted to integer indices. 23 | - `topo` : (int) index of topology. The index can be found in `data/topo_to_idx.json` 24 | - `mc` : (int) index of metal cluster. The index can be found in `data/mc_to_idx.json` 25 | Organic linkers are represented by SELFIES. 26 | - `ol_selfies` : (string) SELFIES string of organic linker. A SMILES string can be converted into SELFIES using `sf.encoder(SMILES)` in `libs/selfies`. 27 | - `ol` : (list) a list of vocabulary indices for the tokens of the SELFIES string. The index of each SELFIES vocabulary token can be found in `data/vocab_to_idx.json` 28 | Finally, the target property you want to optimize should be defined. 29 | - `target` : (float) target property. 30 | 31 | ## 2. Training predictor 32 | 33 | Here is an example of training a predictor for heat of adsorption; run it from the `mofreinforce` directory: 34 | ```angular2html 35 | $ python run_predictor.py with regression_qkh_round3 36 | ``` 37 | You can train your own predictors by modifying `predictor/config_predictor.py`. 
"""Packaging script for the ``mofreinforce`` distribution."""
import re
from pathlib import Path

from setuptools import setup, find_packages

# Fail early with a clear message when torch is not installed.
try:
    import torch
except ImportError as err:
    raise EnvironmentError('Torch must be installed before installation') from err

# Resolve the metadata files next to this script so the build does not depend
# on the current working directory.
here = Path(__file__).resolve().parent

with open(here / "requirements.txt", "r", encoding="utf-8") as f:
    # Keep only real requirement specifiers: strip newlines and drop blank
    # lines and "# ..." comment lines (e.g. "# for tutorial.ipynb").
    install_requires = [
        line.strip()
        for line in f
        if line.strip() and not line.lstrip().startswith("#")
    ]

# Explicit encoding: the README must decode identically on every platform.
with open(here / "README.md", "r", encoding="utf-8") as f:
    long_description = f.read()

extras_require = {
    'docs': ['sphinx', 'livereload', 'myst-parser']
}

# NOTE: the version is hard-coded below. The disabled lookup that would read
# it from the package (and is the only user of ``re``) is kept for reference:
# with open('mofreinforce/__init__.py') as f:
#     version = re.search(r"__version__ = '(?P<version>.+)'", f.read()).group('version')


setup(
    name='mofreinforce',
    version="0.0.1",
    description='mofreinforce',
    long_description=long_description,
    long_description_content_type='text/markdown',
    author='Hyunsoo Park',
    author_email='phs68660888@gmail.com',
    packages=find_packages(),
    package_data={'mofreinforce': []},
    install_requires=install_requires,
    extras_require=extras_require,
    scripts=[],
    download_url='https://github.com/hspark1212/MOFreinforce',
    entry_points={'console_scripts': ['mofreinforce=mofreinforce.cli.main:main']},
    python_requires='>=3.8',
)