├── .gitignore
├── LICENSE
├── README.md
├── mofreinforce
│   ├── __init__.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── download.py
│   │   └── main.py
│   ├── data_preparation.ipynb
│   ├── generator
│   │   ├── __init__.py
│   │   ├── config_generator.py
│   │   ├── datamodule.py
│   │   ├── dataset.py
│   │   ├── logs
│   │   │   └── v0_test_seed0_from_generator
│   │   │       └── version_0
│   │   │           ├── events.out.tfevents.1676286793.park.160184.0
│   │   │           └── hparams.yaml
│   │   ├── module.py
│   │   ├── objectives.py
│   │   └── transformer
│   │       ├── __init__.py
│   │       ├── attention.py
│   │       ├── layers.py
│   │       └── transformer.py
│   ├── libs
│   │   └── selfies
│   │       ├── README.md
│   │       ├── __init__.py
│   │       ├── bond_constraints.py
│   │       ├── compatibility.py
│   │       ├── constants.py
│   │       ├── decoder.py
│   │       ├── encoder.py
│   │       ├── exceptions.py
│   │       ├── grammar_rules.py
│   │       ├── mol_graph.py
│   │       └── utils
│   │           ├── __init__.py
│   │           ├── encoding_utils.py
│   │           ├── matching_utils.py
│   │           ├── selfies_utils.py
│   │           └── smiles_utils.py
│   ├── predictor
│   │   ├── __init__.py
│   │   ├── baseline_model.py
│   │   ├── config_predictor.py
│   │   ├── config_predictor_ex.py
│   │   ├── datamodule.py
│   │   ├── dataset.py
│   │   ├── gadgets.py
│   │   ├── heads.py
│   │   ├── module.py
│   │   ├── objectives.py
│   │   └── transformer.py
│   ├── reinforce
│   │   ├── __init__.py
│   │   ├── config_reinforce.py
│   │   ├── module.py
│   │   └── reinforce.py
│   ├── run_generator.py
│   ├── run_predictor.py
│   ├── run_reinforce.py
│   ├── tutorial.ipynb
│   └── utils
│       ├── __init__.py
│       ├── download.py
│       ├── gadgets.py
│       ├── metrics.py
│       └── module_utils.py
├── predictor.md
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/

*.txt
*.out
*.log
*.ipynb
*.sh
*.npz
*.tar.gz
!data_preparation.ipynb

.ipynb_checkpoints/


# figure
figure/

# log directory
predictor/logs/
predictor/result_transformer

# data directory
1_data_preprocessing/

# generator
generator/logs/

# reinforce
reinforce/logs*/
reinforce/old_logs*/

# lib
PORMAKE/

# big file size
model/
data/
raw_data/
_data/
assets/
results/
test/
logs/
logs_official/
tmp/

# setuptools
build/
dist/
.eggs/
*.egg-info
*.eggs
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Hyunsoo Park

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Reinforcement Learning Framework For MOFs 🚀

This repository provides a reinforcement learning framework for Metal-Organic Frameworks (MOFs), designed to generate MOF structures with user-defined properties. 🔍

The framework consists of two key components: the agent and the environment. The agent (i.e., the generator) generates MOF structures by taking actions. The environment (i.e., the predictor) evaluates these structures by predicting their properties. Based on the prediction, a reward is returned to the agent and used to guide the next round of generation, continually improving the generation process.

## Installation - Get started in minutes! 🌟

### OS and hardware requirements
This package requires Linux Ubuntu 20.04 or 22.04. For optimal performance, we recommend running it with GPUs.

### Dependencies
This package requires Python 3.8 or higher.

### Install
To install this package, first install **PyTorch** (version 1.12.0 or higher) according to your environment, and then follow these steps:

```
$ git clone https://github.com/hspark1212/MOFreinforce.git
$ cd MOFreinforce
$ pip install -e .
```

## Getting Started 💥

### [Download pre-trained models](https://figshare.com/ndownloader/files/39472138)

To train the reinforcement learning framework, you'll need the pre-trained predictors for DAC (direct air capture) and the pre-trained generator. You can download them by running the following command in the `MOFreinforce/mofreinforce` directory:

```bash
$ mofreinforce download default
```
Once downloaded, you can find the pre-trained generator and predictor models in the `mofreinforce/model` directory.

### [Predictor](https://github.com/hspark1212/MOFreinforce/blob/master/mofreinforce/predictor)

In the model directory, you'll find the pre-trained predictors `model/predictor/preditor_qkh.ckpt` and `model/predictor/predictor_selectivity.ckpt` for CO2 heat of adsorption and CO2/H2O selectivity, respectively. If you want to train your own predictor for a different property, refer to [predictor.md](https://github.com/hspark1212/MOFreinforce/blob/master/predictor.md).

### [Generator](https://github.com/hspark1212/MOFreinforce/blob/master/mofreinforce/generator)

The generator selects a topology and a metal cluster and creates an organic linker represented as a SELFIES string. It was pre-trained on about 650,000 MOFs created by PORMAKE, which allows it to generate feasible MOFs. The pre-trained generator `model/generator/generator.ckpt` can be found in the model directory.

### [Reinforcement Learning](https://github.com/hspark1212/MOFreinforce/blob/master/mofreinforce/reinforce)
To run reinforcement learning targeting CO2 heat of adsorption, run the following in the `mofreinforce` directory:
```bash
$ python run_reinforce.py with v0_qkh_round3
```

To run reinforcement learning targeting CO2/H2O selectivity, run the following in the `mofreinforce` directory:
```bash
$ python run_reinforce.py with v1_selectivity_round3
```

You can experiment with other parameters by modifying the [`mofreinforce/reinforce/config_reinforce.py`](https://github.com/hspark1212/MOFreinforce/blob/master/mofreinforce/reinforce/config_reinforce.py) file. You can also run reinforcement learning with your own pre-trained predictor to generate high-performing MOFs for a reward function you define.

### Testing and construction of MOFs by PORMAKE
To test the reinforcement learning, run the following in the `mofreinforce` directory:
```bash
$ python run_reinforce.py with v0_qkh_round3 log_dir=test test_only=True load_path=model/reinforce/best_v0_qkh_round3.ckpt
```
The optimized generators for CO2 heat of adsorption and CO2/H2O selectivity can be found in the `mofreinforce/model` directory.

The MOFs generated from the test set (10,000 samples) can be constructed with [PORMAKE](https://github.com/Sangwon91/PORMAKE).
The details are summarized in the [`tutorial.ipynb`](https://github.com/hspark1212/MOFreinforce/blob/master/mofreinforce/tutorial.ipynb) notebook.

## Contributing 🙌

Contributions are welcome! If you have any suggestions or find any issues, please open an issue or a pull request.

## License 📄

This project is licensed under the MIT License. See the `LICENSE` file for more information.
--------------------------------------------------------------------------------
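The agent/environment loop described in the README can be illustrated with a toy example. This is a minimal sketch, not the repository's training code: the three-candidate action space, the `toy_predictor` scores, and the `train` function are all assumptions made for illustration, and the update uses the exact policy gradient of the expected reward rather than sampled rollouts, so that the result is deterministic.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def toy_predictor(action):
    # Hypothetical stand-in for the environment: a fixed score per candidate.
    return [0.1, 0.9, 0.3][action]


def train(n_actions=3, steps=200, lr=0.5):
    # The "agent" is a softmax policy over a handful of candidate actions.
    logits = [0.0] * n_actions
    for _ in range(steps):
        probs = softmax(logits)
        rewards = [toy_predictor(a) for a in range(n_actions)]
        # Expected reward under the current policy, used as a baseline.
        baseline = sum(p * r for p, r in zip(probs, rewards))
        # d E[reward] / d logit_k = pi_k * (r_k - E[reward])
        for k in range(n_actions):
            logits[k] += lr * probs[k] * (rewards[k] - baseline)
    return softmax(logits)


probs = train()
print([round(p, 3) for p in probs])
```

The policy shifts probability mass toward the candidate that the predictor scores highest, which is the same feedback idea the framework applies to topologies, metal clusters, and linkers.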
/mofreinforce/__init__.py:
--------------------------------------------------------------------------------
import os

__version__ = "0.0.1"
__root_dir__ = os.path.dirname(__file__)

from mofreinforce import predictor, generator, reinforce

# __all__ must list names as strings, not the value of __version__.
__all__ = ["predictor", "generator", "reinforce", "__version__"]
--------------------------------------------------------------------------------
/mofreinforce/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hspark1212/MOFreinforce/b920084a29db9223482ddfeffd2e775b82f49e63/mofreinforce/cli/__init__.py
--------------------------------------------------------------------------------
/mofreinforce/cli/download.py:
--------------------------------------------------------------------------------
class CLICommand:
    """
    Download data and pre-trained models, including a generator and predictors for DAC
    """

    @staticmethod
    def add_arguments(parser):
        from mofreinforce.utils.download import DEFAULT_PATH

        add = parser.add_argument
        add(
            "target",
            nargs="+",
            help="download data and pretrained models including a generator and predictors for DAC",
        )
        add(
            "--outdir",
            "-o",
            help=f"The path where downloaded data will be stored. \n"
            f"default : {DEFAULT_PATH} \n",
        )
        add(
            "--remove_tarfile",
            "-r",
            action="store_true",
            help="remove the downloaded tar.gz file",
        )

    @staticmethod
    def run(args):
        from mofreinforce.utils.download import (
            download_default,
        )

        # Map each accepted target name to its download function.
        func_dic = {
            "default": download_default,
        }

        # Validate every target before starting any download.
        for stuff in args.target:
            if stuff not in func_dic.keys():
                raise ValueError(
                    f'target must be {", ".join(func_dic.keys())}, not {stuff}'
                )

        for stuff in args.target:
            func = func_dic[stuff]
            # Pass remove_tarfile only to functions that accept a second argument.
            if func.__code__.co_argcount == 1:
                func(args.outdir)
            else:
                func(args.outdir, args.remove_tarfile)
--------------------------------------------------------------------------------
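The `run` method above chooses the call signature by inspecting `func.__code__.co_argcount`. A minimal sketch of that dispatch follows; `download_to`, `download_and_clean`, and `dispatch` are hypothetical names used only for illustration and are not part of `mofreinforce.utils.download`.

```python
def download_to(outdir):
    """Hypothetical one-argument downloader."""
    return ("download_to", outdir)


def download_and_clean(outdir, remove_tarfile):
    """Hypothetical two-argument downloader."""
    return ("download_and_clean", outdir, remove_tarfile)


def dispatch(func, outdir, remove_tarfile):
    # co_argcount counts positional parameters (including ones with defaults),
    # but not *args, **kwargs, or keyword-only parameters.
    if func.__code__.co_argcount == 1:
        return func(outdir)
    return func(outdir, remove_tarfile)


results = [dispatch(f, "model", True) for f in (download_to, download_and_clean)]
print(results)
```

One consequence of this design is that a downloader wrapped in a decorator or defined with `*args` would report a different `co_argcount`, so plain functions are the safest fit for this pattern.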
/mofreinforce/cli/main.py:
--------------------------------------------------------------------------------
import argparse
import textwrap
from importlib import import_module
from argparse import RawTextHelpFormatter
from mofreinforce import __version__

commands = [
    ("download", "mofreinforce.cli.download"),
]


def main(prog="mofreinforce", version=__version__, commands=commands, args=None):
    parser = argparse.ArgumentParser(
        prog=prog,
    )
    parser.add_argument(
        "--version", action="version", version="%(prog)s-{}".format(version)
    )
    parser.add_argument("-T", "--traceback", action="store_true")
    subparsers = parser.add_subparsers(title="Sub-command", dest="command")

    subparser = subparsers.add_parser(
        "help", description="Help", help="Help for sub-command."
    )

    subparser.add_argument(
        "helpcommand",
        nargs="?",
        metavar="sub-command",
        help="Provide help for sub-command.",
    )

    functions = {}
    parsers = {}
    for command, module_name in commands:
        cmd = import_module(module_name).CLICommand
        docstring = cmd.__doc__
        if docstring is None:
            # Backwards compatibility with GPAW
            short = cmd.short_description
            long = getattr(cmd, "description", short)
        else:
            # Strip surrounding whitespace so that a docstring starting on the
            # line after the opening quotes still yields a non-empty summary.
            docstring = docstring.strip()
            parts = docstring.split("\n", 1)
            if len(parts) == 1:
                short = docstring
                long = docstring
            else:
                short, body = parts
                long = short
                # long = short + '\n' + textwrap.dedent(body)
        subparser = subparsers.add_parser(
            command, formatter_class=RawTextHelpFormatter, help=short, description=long
        )
        cmd.add_arguments(subparser)
        functions[command] = cmd.run
        parsers[command] = subparser

    args = parser.parse_args(args)

    if args.command == "help":
        if args.helpcommand is None:
            parser.print_help()
        else:
            parsers[args.helpcommand].print_help()
    elif args.command is None:
        parser.print_usage()
    else:
        f = functions[args.command]
        try:
            if f.__code__.co_argcount == 1:
                f(args)
            else:
                f(args, parsers[args.command])
        except KeyboardInterrupt:
            pass
        except Exception as x:
            if args.traceback:
                raise
            else:
                l1 = "{}: {}\n".format(x.__class__.__name__, x)
                l2 = "To get a full traceback, use: {} -T {} ...".format(
                    prog, args.command
                )
                parser.error(l1 + l2)
--------------------------------------------------------------------------------
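The sub-command registration in `main` can be reduced to a small standalone sketch. It assumes a hypothetical `EchoCommand` class and `build_parser` helper (neither exists in the repository) and uses only the standard `argparse` API. The first docstring line is taken as the short help text, following the same convention as `main`.

```python
import argparse


class EchoCommand:
    """Echo the given words.

    Hypothetical command used only for this sketch.
    """

    @staticmethod
    def add_arguments(parser):
        parser.add_argument("words", nargs="+")

    @staticmethod
    def run(args):
        return " ".join(args.words)


def build_parser(commands):
    parser = argparse.ArgumentParser(prog="demo")
    subparsers = parser.add_subparsers(title="Sub-command", dest="command")
    functions = {}
    for name, cmd in commands:
        # The first line of the docstring becomes the short help text.
        short = cmd.__doc__.strip().split("\n", 1)[0]
        sub = subparsers.add_parser(name, help=short, description=short)
        cmd.add_arguments(sub)
        functions[name] = cmd.run
    return parser, functions


parser, functions = build_parser([("echo", EchoCommand)])
args = parser.parse_args(["echo", "hello", "world"])
output = functions[args.command](args)
print(output)
```

Adding a new sub-command then only requires a class with `add_arguments` and `run` and one more entry in the command list.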
/mofreinforce/generator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hspark1212/MOFreinforce/b920084a29db9223482ddfeffd2e775b82f49e63/mofreinforce/generator/__init__.py
--------------------------------------------------------------------------------
/mofreinforce/generator/config_generator.py:
--------------------------------------------------------------------------------
import json
from sacred import Experiment

ex = Experiment("generator")

# Loaded at import time; the paths are relative to the working directory.
with open("data/mc_to_idx.json") as f:
    mc_to_idx = json.load(f)
with open("data/topo_to_idx.json") as f:
    topo_to_idx = json.load(f)


@ex.config
def config():
    seed = 0
    exp_name = "generator"
    log_dir = "generator/logs"
    loss_names = {"generator": 1}

    # datamodule
    dataset_dir = "data/dataset_generator"
    batch_size = 256
    num_workers = 8  # recommend num_gpus * 4
    max_len = 128

    # transformer
    path_topo_to_idx = "data/topo_to_idx.json"
    path_mc_to_idx = "data/mc_to_idx.json"
    path_vocab = "data/vocab_to_idx.json"
    # input_dim = len(vocab_to_idx)
    # output_dim = len(vocab_to_idx)
    hid_dim = 256
    n_layers = 3
    n_heads = 8
    pf_dim = 512
    dropout = 0.1
    src_pad_idx = 0
    trg_pad_idx = 0

    # Trainer
    per_gpu_batchsize = 128
    num_nodes = 1
    devices = 2
    precision = 16
    resume_from = None
    val_check_interval = 1.0
    test_only = False
    load_path = ""
    gradient_clip_val = None

    # Optimizer Setting
    optim_type = "adam"  # adamw, adam, sgd (momentum=0.9)
    learning_rate = 5e-4
    weight_decay = 0
    decay_power = "constant"  # default polynomial decay, [cosine, constant, constant_with_warmup]
    max_epochs = 100
    max_steps = -1  # num_data * max_epoch // batch_size (accumulate_grad_batches)
    warmup_steps = 0.0  # int or float ( max_steps * warmup_steps)
    end_lr = 0
    lr_mult = 1  # multiply lr for downstream heads


@ex.named_config
def v0():
    exp_name = "v0"


@ex.named_config
def test():
    exp_name = "test"


@ex.named_config
def v0_test():
    exp_name = "v0_test"
    load_path = "model/generator.ckpt"

    test_only = True
    # The base config defines `devices`; the original listing set `num_devices`
    # here, which looks like a stale key name, so `devices` is assumed.
    devices = 1


"""
old experiments
"""


@ex.named_config
def v0_grad_clip():
    exp_name = "v0_grad_clip"

    gradient_clip_val = 0.5
--------------------------------------------------------------------------------
/mofreinforce/generator/datamodule.py:
--------------------------------------------------------------------------------
from typing import Optional

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from generator.dataset import GeneratorDataset


class GeneratorDatamodule(LightningDataModule):
    def __init__(self, _config):
        super().__init__()
        self.dataset_dir = _config["dataset_dir"]
        self.batch_size = _config["batch_size"]
        self.num_workers = _config["num_workers"]
        self.max_len = _config["max_len"]
        self.path_vocab = _config["path_vocab"]
        self.path_topo_to_idx = _config["path_topo_to_idx"]
        self.path_mc_to_idx = _config["path_mc_to_idx"]

    @property
    def dataset_cls(self):
        return GeneratorDataset

    def set_train_dataset(self):
        self.train_dataset = self.dataset_cls(
            dataset_dir=self.dataset_dir,
            path_vocab=self.path_vocab,
            path_topo_to_idx=self.path_topo_to_idx,
            path_mc_to_idx=self.path_mc_to_idx,
            split="train",
            max_len=self.max_len,
        )

    def set_val_dataset(self):
        self.val_dataset = self.dataset_cls(
            dataset_dir=self.dataset_dir,
            path_vocab=self.path_vocab,
            path_topo_to_idx=self.path_topo_to_idx,
            path_mc_to_idx=self.path_mc_to_idx,
            split="val",
            max_len=self.max_len,
        )

    def set_test_dataset(self):
        self.test_dataset = self.dataset_cls(
            dataset_dir=self.dataset_dir,
            path_vocab=self.path_vocab,
            path_topo_to_idx=self.path_topo_to_idx,
            path_mc_to_idx=self.path_mc_to_idx,
            split="test",
            max_len=self.max_len,
        )

    def setup(self, stage: Optional[str] = None):
        if stage in (None, "fit"):
            self.set_train_dataset()
            self.set_val_dataset()

        if stage in (None, "test"):
            self.set_test_dataset()

        self.collate = self.dataset_cls.collate

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate,
            shuffle=False,
            pin_memory=True,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate,
            shuffle=False,
            pin_memory=True,
        )
--------------------------------------------------------------------------------
/mofreinforce/generator/dataset.py:
--------------------------------------------------------------------------------
import json
import numpy as np
from pathlib import Path
import pandas as pd
import libs.selfies as sf

import torch
from torch.utils.data import Dataset


class GeneratorDataset(Dataset):
    def __init__(
        self, dataset_dir, path_vocab, path_topo_to_idx, path_mc_to_idx, split, max_len
    ):
        assert split in ["train", "test", "val"]
        # read the token, topology and metal-cluster lookup tables
        with open(path_vocab) as f:
            self.vocab_to_idx = json.load(f)
        with open(path_topo_to_idx) as f:
            self.topo_to_idx = json.load(f)
        with open(path_mc_to_idx) as f:
            self.mc_to_idx = json.load(f)

        # load dataset
        path_data = Path(dataset_dir, f"{split}.csv")
        csv_ = pd.read_csv(path_data)
        print(f"read file : {path_data}, num_data : {len(csv_)}")
        self.topo = np.array(csv_["topo"])
        self.mc = np.array(csv_["mc"])
        self.num_conn = np.array(csv_["num_conn"])
        self.frags = np.array(csv_["frags"])
        self.selfies = np.array(csv_["selfies"])

        self.encoded_input, self.encoded_output = self.encoding(max_len)

    def encoding(self, max_len):
        # making encoded_input: mc, num_conn, [SOS], frags..., [EOS], [PAD]...
        encoded_input = []
        for i, f in enumerate(self.frags):
            encoded_frags = [self.vocab_to_idx[v] for v in sf.split_selfies(f)]
            encoded = (
                [self.mc_to_idx[self.mc[i]]]
                + [self.num_conn[i]]
                + [self.vocab_to_idx["[SOS]"]]
                + encoded_frags
                + [self.vocab_to_idx["[EOS]"]]
                + [self.vocab_to_idx["[PAD]"]] * (max_len - 4 - len(encoded_frags))
            )
            encoded_input.append(encoded)

        # making encoded_output: [SOS], topo, mc, selfies..., [EOS], [PAD]...
        encoded_output = []
        for i, f in enumerate(self.selfies):
            encoded_selfies = [self.vocab_to_idx[v] for v in sf.split_selfies(f)]
            encoded = (
                [self.vocab_to_idx["[SOS]"]]
                + [self.topo_to_idx[self.topo[i]]]
                + [self.mc_to_idx[self.mc[i]]]
                + encoded_selfies
                + [self.vocab_to_idx["[EOS]"]]
                + [self.vocab_to_idx["[PAD]"]] * (max_len - 4 - len(encoded_selfies))
            )
            encoded_output.append(encoded)

        return encoded_input, encoded_output

    def __len__(self):
        return len(self.selfies)

    def __getitem__(self, idx):
        ret = dict()

        ret.update(
            {
                "topo": self.topo[idx],
                "mc": self.mc[idx],
                "frags": self.frags[idx],
                "selfies": self.selfies[idx],
                "encoded_input": self.encoded_input[idx],
                "encoded_output": self.encoded_output[idx],
            }
        )
        return ret

    @staticmethod
    def collate(batch):
        keys = set([key for b in batch for key in b.keys()])
        dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys}
        dict_batch["encoded_input"] = torch.LongTensor(dict_batch["encoded_input"])
        dict_batch["encoded_output"] = torch.LongTensor(dict_batch["encoded_output"])
        return dict_batch
--------------------------------------------------------------------------------
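The target layout built in `GeneratorDataset.encoding` (`[SOS]`, topology index, metal-cluster index, linker tokens, `[EOS]`, then `[PAD]` up to `max_len`) can be checked with a toy vocabulary. The sketch below is illustrative only: `encode_target` and `toy_vocab` are hypothetical, and the explicit length check is an addition that the original method does not perform.

```python
def encode_target(topo_idx, mc_idx, linker_tokens, vocab_to_idx, max_len):
    """Build [SOS] topo mc linker... [EOS] [PAD]... with exactly max_len entries."""
    body = [vocab_to_idx[t] for t in linker_tokens]
    # Four slots are reserved for [SOS], topo, mc and [EOS].
    n_pad = max_len - 4 - len(body)
    if n_pad < 0:
        # Added for the sketch: fail loudly instead of returning a ragged row.
        raise ValueError("linker is too long for max_len")
    return (
        [vocab_to_idx["[SOS]"], topo_idx, mc_idx]
        + body
        + [vocab_to_idx["[EOS]"]]
        + [vocab_to_idx["[PAD]"]] * n_pad
    )


# Hypothetical vocabulary; the real one is read from data/vocab_to_idx.json.
toy_vocab = {"[PAD]": 0, "[SOS]": 1, "[EOS]": 2, "[C]": 3, "[=O]": 4}
encoded = encode_target(7, 5, ["[C]", "[=O]", "[C]"], toy_vocab, max_len=10)
print(encoded)
```

Because every row has the same length, a batch of such rows can be passed directly to `torch.LongTensor`, as `collate` does.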
/mofreinforce/generator/logs/v0_test_seed0_from_generator/version_0/events.out.tfevents.1676286793.park.160184.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hspark1212/MOFreinforce/b920084a29db9223482ddfeffd2e775b82f49e63/mofreinforce/generator/logs/v0_test_seed0_from_generator/version_0/events.out.tfevents.1676286793.park.160184.0
--------------------------------------------------------------------------------
/mofreinforce/generator/logs/v0_test_seed0_from_generator/version_0/hparams.yaml:
--------------------------------------------------------------------------------
config:
  batch_size: 256
  dataset_dir: data/dataset_generator
  decay_power: constant
  dropout: 0.1
  end_lr: 0
  exp_name: v0_test
  gradient_clip_val: null
  hid_dim: 256
  learning_rate: 0.0005
  load_path: model/generator.ckpt
  log_dir: generator/logs
  loss_names:
    generator: 1
  lr_mult: 1
  max_epochs: 100
  max_len: 128
  max_steps: -1
  n_heads: 8
  n_layers: 3
  num_devices: 1
  num_nodes: 1
  num_workers: 8
  optim_type: adam
  path_mc_to_idx: data/mc_to_idx.json
  path_topo_to_idx: data/topo_to_idx.json
  path_vocab: data/vocab_to_idx.json
  per_gpu_batchsize: 128
  pf_dim: 512
  precision: 16
  resume_from: null
  seed: 0
  src_pad_idx: 0
  test_only: true
  trg_pad_idx: 0
  val_check_interval: 1.0
  warmup_steps: 0.0
  weight_decay: 0
--------------------------------------------------------------------------------
/mofreinforce/generator/module.py:
--------------------------------------------------------------------------------
import json
import random
from tqdm import tqdm

import numpy as np
import libs.selfies as sf
from rdkit import Chem

import torch
from pytorch_lightning import LightningModule

from generator import objectives
from generator.transformer import Transformer
from utils import module_utils

from utils.metrics import Metrics
from rdkit.Chem.Draw import MolToImage


class Generator(LightningModule):
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()

        self.max_len = config["max_len"]
        # topo
        with open(config["path_topo_to_idx"]) as f:
            self.topo_to_idx = json.load(f)
        self.idx_to_topo = {v: k for k, v in self.topo_to_idx.items()}
        # mc
        with open(config["path_mc_to_idx"]) as f:
            self.mc_to_idx = json.load(f)
        self.idx_to_mc = {v: k for k, v in self.mc_to_idx.items()}
        # ol
        with open(config["path_vocab"]) as f:
            self.vocab_to_idx = json.load(f)
        self.idx_to_vocab = {v: k for k, v in self.vocab_to_idx.items()}

        self.transformer = Transformer(
            input_dim=len(self.vocab_to_idx),
            output_dim=len(self.vocab_to_idx),
            topo_dim=len(self.topo_to_idx),
            mc_dim=len(self.mc_to_idx),
            hid_dim=config["hid_dim"],
            n_layers=config["n_layers"],
            n_heads=config["n_heads"],
            pf_dim=config["pf_dim"],
            dropout=config["dropout"],
            max_len=config["max_len"],
            src_pad_idx=config["src_pad_idx"],
            trg_pad_idx=config["trg_pad_idx"],
        )

        module_utils.set_metrics(self)
        # ===================== load model ======================

        if config["load_path"] != "":
            ckpt = torch.load(config["load_path"], map_location="cpu")
            state_dict = ckpt["state_dict"]
            self.load_state_dict(state_dict, strict=False)
            print(f"load model : {config['load_path']}")

    def infer(self, batch):
        src = batch["encoded_input"]  # [B, max_len]

        # teacher forcing: feed tokens 0..n-2, predict tokens 1..n-1
        tgt_input = batch["encoded_output"][:, :-1]  # [B, seq_len-1]
        tgt_label = batch["encoded_output"][:, 1:]  # [B, seq_len-1]

        # masks are built inside the transformer
        out = self.transformer(src, tgt_input)  # [B, seq_len-1, vocab_dim]

        out.update(
            {
                "src": src,
                "tgt": batch["encoded_output"],
                "tgt_label": tgt_label,
            }
        )

        return out

    def evaluate(self, src, max_len=128):
        """
        Greedy decoding for a single example.

        src : torch.LongTensor [B=1, seq_len]
        """
        vocab_to_idx = self.vocab_to_idx
        src_mask = self.transformer.make_src_mask(src)

        # get encoded src
        enc_src = self.transformer.encoder(src, src_mask)  # [B=1, seq_len, hid_dim]

        # get target
        trg_indexes = [vocab_to_idx["[SOS]"]]
        for i in range(max_len):
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(src.device)
            trg_mask = self.transformer.make_trg_mask(
                trg_tensor
            )  # [B=1, 1, seq_len, seq_len]

            output = self.transformer.decoder(
                trg_tensor, enc_src, trg_mask, src_mask
            )  # [B=1, seq_len, vocab_dim]

            # the decoder returns the head that matches the current position:
            # topology first, then metal cluster, then organic-linker tokens
            if "output_ol" in output.keys():
                out = output["output_ol"]
                pred_token = out.argmax(-1)[:, -1].item()
            elif "output_mc" in output.keys():
                out = output["output_mc"]
                pred_token = out.argmax(-1).item()
            else:
                out = output["output_topo"]
                pred_token = out.argmax(-1).item()
            trg_indexes.append(pred_token)

            if i > 3 and pred_token == vocab_to_idx["[EOS]"]:
                break

        # get topo
        topo_idx = trg_indexes[1]
        topo = self.idx_to_topo[topo_idx]
        # get mc
        mc_idx = trg_indexes[2]
        mc = self.idx_to_mc[mc_idx]
        # get ol
        ol_idx = trg_indexes[3:]
        ol_tokens = [self.idx_to_vocab[idx] for idx in ol_idx]
        # convert to selfies and smiles
        gen_sf = None
        gen_sm = None
        try:
            gen_sf = "".join(ol_tokens[:-1])  # remove EOS token
            gen_sm = sf.decoder(gen_sf)
            m = Chem.MolFromSmiles(gen_sm)
            gen_sm = Chem.MolToSmiles(m)  # canonical smiles
        except Exception as e:
            print(e)
            # do not return a non-canonical or invalid SMILES as a success
            gen_sm = None

        ret = {
            "topo": topo,
            "mc": mc,
            "topo_idx": topo_idx,
            "mc_idx": mc_idx,
            "ol_idx": ol_idx,
            "gen_sf": gen_sf,
            "gen_sm": gen_sm,
        }
        return ret

    def forward(self, batch):
        ret = dict()
        ret.update(objectives.compute_loss(self, batch))

        return ret

    def training_step(self, batch, batch_idx):
        ret = self(batch)
        total_loss = sum([v for k, v in ret.items() if "loss" in k])

        return total_loss

    def training_epoch_end(self, outputs):
        module_utils.epoch_wrapup(self)

    def validation_step(self, batch, batch_idx):
        # losses and accuracies are logged inside compute_loss
        self(batch)

    def validation_epoch_end(self, output):
        module_utils.epoch_wrapup(self)

    def test_step(self, batch, batch_idx):
        return batch

    def test_epoch_end(self, batches):
        split = "test"
        module_utils.epoch_wrapup(self)

        metrics = Metrics(self.vocab_to_idx, self.idx_to_vocab)
        list_src = torch.concat([b["encoded_input"] for b in batches], dim=0)

        for src in tqdm(list_src):
            out = self.evaluate(src.unsqueeze(0))

            if out["gen_sm"] is None:
                metrics.num_fail.append(1)
                continue
            else:
                metrics.num_fail.append(0)

            metrics.update(out, src)

        self.log(f"{split}/conn_match", metrics.get_mean(metrics.conn_match))
        self.log(f"{split}/unique_ol", len(set(metrics.gen_ol)))
        self.log(
            f"{split}/unique_topo_mc", len(set(zip(metrics.gen_topo, metrics.gen_mc)))
        )
        self.log(f"{split}/scaffold", metrics.get_mean(metrics.scaffold))
        self.log(f"{split}/num_fail", metrics.get_mean(metrics.num_fail))

        # add image to log
        # gen_ol with frags (32 images)
        for i in range(32):
            idx = random.Random(i).choice(range(len(metrics.gen_ol)))
            ol = metrics.gen_ol[idx]
            frags = metrics.input_frags[idx]
            imgs = []
            for s in [ol] + frags:
                m = Chem.MolFromSmiles(s)
                if not m:
                    continue
                img = MolToImage(m)
                img = np.array(img)
                img = torch.tensor(img)
                imgs.append(img)
            imgs = np.stack(imgs, axis=0)
            self.logger.experiment.add_image(
                f"{split}/{i}", imgs, self.global_step, dataformats="NHWC"
            )

        # total gen_ol
        imgs = []
        for i, m in enumerate(metrics.gen_ol[:32]):
            try:
                m = Chem.MolFromSmiles(m)
                img = MolToImage(m)
                img = np.array(img)
                img = torch.tensor(img)
                imgs.append(img)
            except Exception as e:
                print(e)
        imgs = np.stack(imgs, axis=0)
        self.logger.experiment.add_image(
            f"{split}/gen_ol/", imgs, self.global_step, dataformats="NHWC"
        )

    def configure_optimizers(self):
        return module_utils.set_schedule(self)
--------------------------------------------------------------------------------
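`Generator.evaluate` decodes greedily: it appends the most likely next token until `[EOS]` appears, but ignores `[EOS]` during the first few steps because topology and metal-cluster indices share the integer range with vocabulary tokens. The sketch below isolates that loop. The `greedy_decode` function, the `skip_eos_steps` parameter, and the scripted `toy_model` are assumptions made for illustration and do not reproduce the exact `i > 3` threshold used above.

```python
def greedy_decode(next_token, sos, eos, max_len, skip_eos_steps=0):
    """Append next_token(prefix) until eos is produced or max_len steps pass.

    During the first `skip_eos_steps` steps a token equal to `eos` does not
    stop decoding, because those positions hold indices from other tables.
    """
    tokens = [sos]
    for step in range(max_len):
        tok = next_token(tokens)
        tokens.append(tok)
        if step >= skip_eos_steps and tok == eos:
            break
    return tokens


SOS, EOS = 1, 2
scripted = [2, 2, 5, 2, 9]  # hypothetical model outputs, one per step


def toy_model(prefix):
    # Stand-in for the decoder: returns the scripted token for this step.
    return scripted[len(prefix) - 1]


decoded = greedy_decode(toy_model, SOS, EOS, max_len=8, skip_eos_steps=2)
print(decoded)
```

The first two generated values equal the `[EOS]` id but are kept, since they represent a topology index and a metal-cluster index; decoding stops only at the later `[EOS]`.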
/mofreinforce/generator/objectives.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torchmetrics.functional import accuracy
3 |
4 |
5 | def compute_loss(pl_module, batch):
6 | infer = pl_module.infer(batch)
7 |
8 | tgt_label = infer["tgt_label"] # [B, seq_len]
9 |
10 | batch_size, _, vocab_dim = infer["output_ol"].shape
11 |
12 | # loss topo
13 | logit_topo = infer["output_topo"]
14 | label_topo = tgt_label[:, 0]
15 | loss_topo = F.cross_entropy(logit_topo, label_topo)
16 | # loss mc
17 | logit_mc = infer["output_mc"]
18 | label_mc = tgt_label[:, 1]
19 | loss_mc = F.cross_entropy(logit_mc, label_mc)
20 | # loss ol
21 | logit_ol = infer["output_ol"].reshape(-1, vocab_dim)
22 | label_ol = tgt_label[:, 2:].reshape(-1)
23 | loss_ol = F.cross_entropy(logit_ol, label_ol, ignore_index=0)
24 | # total loss
25 | total_loss = loss_topo + loss_mc + loss_ol
26 |
27 | ret = {
28 | "gen_loss": total_loss,
29 | "gen_labels": tgt_label,
30 | }
31 |
32 | # call update() loss and acc
33 | loss_name = "generator"
34 | phase = "train" if pl_module.training else "val"
35 | total_loss = getattr(pl_module, f"{phase}_{loss_name}_loss")(ret["gen_loss"])
36 |
37 | # acc
38 | acc_topo = getattr(pl_module, f"{phase}_{loss_name}_acc_topo")(
39 | accuracy(logit_topo.argmax(-1), label_topo)
40 | )
41 | acc_mc = getattr(pl_module, f"{phase}_{loss_name}_acc_mc")(
42 | accuracy(logit_mc.argmax(-1), label_mc)
43 | )
44 | acc_ol = getattr(pl_module, f"{phase}_{loss_name}_acc_ol")(
45 | accuracy(logit_ol.argmax(-1), label_ol, ignore_index=0)
46 | )
47 |
48 | pl_module.log(
49 | f"{loss_name}/{phase}/total_loss",
50 | total_loss,
51 | batch_size=batch_size,
52 | prog_bar=True,
53 | sync_dist=True,
54 | )
55 | pl_module.log(
56 | f"{loss_name}/{phase}/loss_topo",
57 | loss_topo,
58 | batch_size=batch_size,
59 | prog_bar=False,
60 | sync_dist=True,
61 | )
62 | pl_module.log(
63 | f"{loss_name}/{phase}/loss_mc",
64 | loss_mc,
65 | batch_size=batch_size,
66 | prog_bar=False,
67 | sync_dist=True,
68 | )
69 | pl_module.log(
70 | f"{loss_name}/{phase}/loss_ol",
71 | loss_ol,
72 | batch_size=batch_size,
73 | prog_bar=False,
74 | sync_dist=True,
75 | )
76 | pl_module.log(
77 | f"{loss_name}/{phase}/acc_topo",
78 | acc_topo,
79 | batch_size=batch_size,
80 | prog_bar=True,
81 | sync_dist=True,
82 | )
83 | pl_module.log(
84 | f"{loss_name}/{phase}/acc_mc",
85 | acc_mc,
86 | batch_size=batch_size,
87 | prog_bar=True,
88 | sync_dist=True,
89 | )
90 | pl_module.log(
91 | f"{loss_name}/{phase}/acc_ol",
92 | acc_ol,
93 | batch_size=batch_size,
94 | prog_bar=True,
95 | sync_dist=True,
96 | )
97 | return ret
98 |
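The loss assembled in `compute_loss` above is the sum of three cross-entropy terms: topology (target position 0), metal cluster (position 1), and organic-linker tokens (positions 2 onward, with index 0 ignored as padding). The following is a minimal sketch of that composition on random tensors, assuming PyTorch is available. The function name `generator_loss` and all sizes are hypothetical and chosen only for illustration.

```python
import torch
import torch.nn.functional as F


def generator_loss(logit_topo, logit_mc, logit_ol, tgt_label):
    """Sum of topology, metal-cluster and organic-linker cross-entropies."""
    vocab_dim = logit_ol.shape[-1]
    loss_topo = F.cross_entropy(logit_topo, tgt_label[:, 0])
    loss_mc = F.cross_entropy(logit_mc, tgt_label[:, 1])
    # Flatten the linker tokens; label 0 is treated as padding and skipped.
    loss_ol = F.cross_entropy(
        logit_ol.reshape(-1, vocab_dim),
        tgt_label[:, 2:].reshape(-1),
        ignore_index=0,
    )
    return loss_topo + loss_mc + loss_ol


torch.manual_seed(0)
B, ol_len, topo_dim, mc_dim, vocab_dim = 4, 6, 5, 3, 10  # made-up sizes

# Target layout: [topo, mc, ol_1, ..., ol_n]
tgt_label = torch.cat(
    [
        torch.randint(0, topo_dim, (B, 1)),
        torch.randint(0, mc_dim, (B, 1)),
        torch.randint(1, vocab_dim, (B, ol_len)),
    ],
    dim=1,
)
tgt_label[:, -2:] = 0  # pad the last two linker positions

logit_topo = torch.randn(B, topo_dim)
logit_mc = torch.randn(B, mc_dim)
logit_ol = torch.randn(B, ol_len, vocab_dim)

loss = generator_loss(logit_topo, logit_mc, logit_ol, tgt_label)
```

Because of `ignore_index=0`, the logits at padded linker positions do not contribute to the total, so sequences of different lengths can share a batch without biasing the linker term.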
--------------------------------------------------------------------------------
/mofreinforce/generator/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import Transformer, Encoder, Decoder
--------------------------------------------------------------------------------
/mofreinforce/generator/transformer/attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | def attention(query, key, value, mask=None, dropout=None):
7 | """
8 |
9 | :param query: [B, num_heads, seq_len, hid_dim//num_heads]
10 | :param key: [B, num_heads, seq_len, hid_dim//num_heads]
11 | :param value: [B, num_heads, seq_len, hid_dim//num_heads]
12 | :param mask: [B, 1, 1, seq_len]
13 |     :param dropout: (nn.Dropout or None) dropout module applied to the attention weights
14 |     :return: ([B, num_heads, seq_len, hid_dim//num_heads], [B, num_heads, seq_len, seq_len])
15 |     """
16 |     # Compute scaled dot-product attention
17 |     d_k = query.size(-1)  # hid_dim // num_heads
18 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # [B, num_heads, seq_len, seq_len]
19 | if mask is not None:
20 | scores = scores.masked_fill(mask == 0, -float('inf')) # [B, num_heads, seq_len, seq_len]
21 |
22 | p_attn = scores.softmax(dim=-1) # [B, num_heads, seq_len, seq_len]
23 |
24 | if dropout is not None:
25 | p_attn = dropout(p_attn)
26 | return torch.matmul(p_attn, value), p_attn # [B, num_heads, seq_len, hid_dim//num_heads]
27 |
28 |
29 | class MultiHeadedAttention(nn.Module):
30 | def __init__(self, h, d_model, dropout_rate=0.1):
31 | super(MultiHeadedAttention, self).__init__()
32 | assert d_model % h == 0
33 | # We assume d_v always equals d_k
34 | self.d_k = d_model // h
35 | self.h = h
36 | self.layer_q = nn.Linear(d_model, d_model)
37 | self.layer_k = nn.Linear(d_model, d_model)
38 | self.layer_v = nn.Linear(d_model, d_model)
39 | self.proj = nn.Linear(d_model, d_model)
40 |
41 | self.dropout = nn.Dropout(p=dropout_rate)
42 |
43 | def forward(self, q, k, v, mask=None):
44 | # query : [B, seq_len, hid_dim]
45 | # mask : [B, 1, seq_len] for src, [B, seq_len, seq_len] for tgt
46 | if mask is not None:
47 |             mask = mask.unsqueeze(1)  # [B, 1, 1, seq_len] for src, [B, 1, seq_len, seq_len] for tgt
48 | batch_size = q.size(0)
49 |
50 |         # 1) Project and split heads: [B, seq_len, num_heads, hid_dim//num_heads] -> [B, num_heads, seq_len, hid_dim//num_heads]
51 |         q = self.layer_q(q).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
52 |         k = self.layer_k(k).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
53 |         v = self.layer_v(v).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
54 |
55 | # 2) Apply attention on all the projected vectors in batch.
56 | x, p_attn = attention(
57 | q, k, v, mask=mask, dropout=self.dropout
58 | ) # [B, num_heads, seq_len, hid_dim//num_heads]
59 |
60 | # 3) "Concat" using a view and apply a final linear.
61 | x = (
62 | x.transpose(1, 2) # [B, seq_len, num_heads, hid_dim//num_heads]
63 | .contiguous()
64 | .view(batch_size, -1, self.h * self.d_k) # [B, seq_len, hid_dim]
65 | )
66 |
67 | x = self.proj(x)
68 | return x, p_attn
69 |
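The masking step in `attention` above can be checked in isolation. The following is a minimal sketch, assuming PyTorch is available; the function name `masked_attention` and the toy shapes are hypothetical. It builds a padding mask of shape `[B, 1, 1, seq_len]` that hides the last key position and confirms that the softmax assigns that position zero weight.

```python
import math

import torch


def masked_attention(query, key, value, mask=None):
    """Scaled dot-product attention; positions where mask == 0 are hidden."""
    d_k = query.size(-1)  # per-head dimension
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    p_attn = scores.softmax(dim=-1)
    return torch.matmul(p_attn, value), p_attn


torch.manual_seed(0)
B, H, L, D = 2, 4, 5, 8  # batch, heads, sequence length, head dim (made up)
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# Padding mask broadcast over heads and query positions: hide the last key.
mask = torch.ones(B, 1, 1, L, dtype=torch.bool)
mask[:, :, :, -1] = False

out, p_attn = masked_attention(q, k, v, mask)
```

One design point to keep in mind: filling with `-inf` produces NaN for a query row whose keys are all masked, because the softmax then divides zero by zero. A large negative constant such as the `-1e10` used in `layers.py` avoids that case at the cost of a tiny, nonzero weight on masked positions.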
--------------------------------------------------------------------------------
/mofreinforce/generator/transformer/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class MultiHeadAttentionLayer(nn.Module):
5 | def __init__(self, hid_dim, n_heads, dropout):
6 | super().__init__()
7 |
8 | assert hid_dim % n_heads == 0
9 |
10 | self.hid_dim = hid_dim
11 | self.n_heads = n_heads
12 | self.head_dim = hid_dim // n_heads
13 |
14 | self.fc_q = nn.Linear(hid_dim, hid_dim)
15 | self.fc_k = nn.Linear(hid_dim, hid_dim)
16 | self.fc_v = nn.Linear(hid_dim, hid_dim)
17 |
18 | self.fc_o = nn.Linear(hid_dim, hid_dim)
19 |
20 | self.dropout = nn.Dropout(dropout)
21 |
22 | self.scale = torch.sqrt(torch.FloatTensor([self.head_dim]))
23 |
24 | def forward(self, query, key, value, mask=None):
25 | """
26 | :param query: [B, seq_len, hid_dim]
27 | :param key: [B, seq_len, hid_dim]
28 | :param value: [B, seq_len, hid_dim]
29 | :param mask: [B, 1, 1, src_len] for src_mask, [B, 1, trg_len, trg_len] for trg_mask
30 | :return: [B, seq_len, hid_dim], [B, n_heads, seq_len, seq_len]
31 | """
32 | batch_size = query.shape[0]
33 |
34 | Q = self.fc_q(query) # [batch size, query len, hid dim]
35 | K = self.fc_k(key) # [batch size, key len, hid dim]
36 | V = self.fc_v(value) # [batch size, value len, hid dim]
37 |
38 | Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
39 | K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
40 | V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
41 | # [batch size, n heads, query len, head dim]
42 | # [batch size, n heads, key len, head dim]
43 | # [batch size, n heads, value len, head dim]
44 | energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale.to(query.device)
45 | # energy = [batch size, n heads, query len, key len]
46 | if mask is not None:
47 | energy = energy.masked_fill(mask == 0, -1e10)
48 | attention = torch.softmax(energy, dim=-1) # [batch size, n heads, query len, key len]
49 | x = torch.matmul(self.dropout(attention), V) # [batch size, n heads, query len, head dim]
50 | x = x.permute(0, 2, 1, 3).contiguous() # [batch size, query len, n heads, head dim]
51 | x = x.view(batch_size, -1, self.hid_dim) # [batch size, query len, hid dim]
52 | x = self.fc_o(x) # [batch size, query len, hid dim]
53 | return x, attention
54 |
55 |
56 | class PositionwiseFeedforwardLayer(nn.Module):
57 | def __init__(self, hid_dim, pf_dim, dropout):
58 | super().__init__()
59 |
60 | self.fc_1 = nn.Linear(hid_dim, pf_dim)
61 | self.fc_2 = nn.Linear(pf_dim, hid_dim)
62 |
63 | self.dropout = nn.Dropout(dropout)
64 |
65 | def forward(self, x):
66 | """
67 | :param x: [B, seq_len, hid dim]
68 | :return: [B, seq_len, hid_dim]
69 | """
70 | x = self.dropout(torch.relu(self.fc_1(x))) # [batch size, seq len, pf dim]
71 | x = self.fc_2(x) # [batch size, seq len, hid dim]
72 | return x
73 |
74 |
75 | class EncoderLayer(nn.Module):
76 | def __init__(self,
77 | hid_dim,
78 | n_heads,
79 | pf_dim,
80 | dropout,
81 | ):
82 | super().__init__()
83 |
84 | self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
85 | self.ff_layer_norm = nn.LayerNorm(hid_dim)
86 | self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
87 | self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim,
88 | pf_dim,
89 | dropout)
90 | self.dropout = nn.Dropout(dropout)
91 |
92 | def forward(self, src, src_mask):
93 | """
94 | :param src: [B, src_len, hid_dim]
95 | :param src_mask: [B, 1, 1, src_len]
96 | :return: [B, src_len, hid_dim]
97 | """
98 | # self attention
99 | _src, _ = self.self_attention(src, src, src, src_mask) # [batch_size, src_len, hid_dim]
100 |
101 | # dropout, residual connection and layer norm
102 | src = self.self_attn_layer_norm(src + self.dropout(_src)) # [batch_size, src_len, hid_dim]
103 |
104 | # positionwise feedforward
105 | _src = self.positionwise_feedforward(src) # [batch_size, src_len, hid_dim]
106 |
107 | # dropout, residual and layer norm
108 | src = self.ff_layer_norm(src + self.dropout(_src)) # [batch_size, src_len, hid_dim]
109 | return src
110 |
111 |
112 | class DecoderLayer(nn.Module):
113 | def __init__(self,
114 | hid_dim,
115 | n_heads,
116 | pf_dim,
117 | dropout,
118 | ):
119 | super().__init__()
120 |
121 | self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
122 | self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
123 | self.ff_layer_norm = nn.LayerNorm(hid_dim)
124 | self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
125 | self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
126 | self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim,
127 | pf_dim,
128 | dropout)
129 | self.dropout = nn.Dropout(dropout)
130 |
131 | def forward(self, trg, enc_src, trg_mask, src_mask):
132 | """
133 | :param trg: [B, trg_len, hid_dim]
134 | :param enc_src: [B, src_len, hid_dim]
135 | :param trg_mask: [B, 1, trg_len, trg_len]
136 |         :param src_mask: [B, 1, 1, src_len]
137 | :return: [B, trg_len, hid_dim], [B, n_heads, trg_len, src_len]
138 | """
139 |
140 | # self attention
141 | _trg, _ = self.self_attention(trg, trg, trg, trg_mask) # [batch size, trg len, hid dim]
142 |
143 | # dropout, residual connection and layer norm
144 | trg = self.self_attn_layer_norm(trg + self.dropout(_trg)) # [batch size, trg len, hid dim]
145 |
146 | # encoder attention
147 | _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
148 | # [batch size, trg len, hid dim], [batch size, n heads, trg len, src len]
149 |
150 | # dropout, residual connection and layer norm
151 | trg = self.enc_attn_layer_norm(trg + self.dropout(_trg)) # [batch size, trg len, hid dim]
152 |
153 | # positionwise feedforward
154 | _trg = self.positionwise_feedforward(trg) # [batch size, trg len, hid dim]
155 |
156 | # dropout, residual and layer norm
157 | trg = self.ff_layer_norm(trg + self.dropout(_trg)) # [batch size, trg len, hid dim]
158 |
159 | return trg, attention
--------------------------------------------------------------------------------
/mofreinforce/generator/transformer/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from .layers import EncoderLayer, DecoderLayer
5 |
6 |
7 | class Transformer(nn.Module):
8 | """
9 |     Transformer for MOF Generator
10 | """
11 | def __init__(self,
12 | input_dim,
13 | output_dim,
14 | topo_dim,
15 | mc_dim,
16 | hid_dim,
17 | n_layers,
18 | n_heads,
19 | pf_dim,
20 | dropout,
21 | max_len,
22 | src_pad_idx,
23 | trg_pad_idx,
24 | ):
25 | super().__init__()
26 |
27 | self.encoder = Encoder(
28 | input_dim=input_dim,
29 | mc_dim=mc_dim,
30 | hid_dim=hid_dim,
31 | n_layers=n_layers,
32 | n_heads=n_heads,
33 | pf_dim=pf_dim,
34 | dropout=dropout,
35 | max_len=max_len,
36 | )
37 | self.decoder = Decoder(
38 | output_dim=output_dim,
39 | topo_dim=topo_dim,
40 | mc_dim=mc_dim,
41 | hid_dim=hid_dim,
42 | n_layers=n_layers,
43 | n_heads=n_heads,
44 | pf_dim=pf_dim,
45 | dropout=dropout,
46 | max_len=max_len,
47 | )
48 | self.src_pad_idx = src_pad_idx
49 | self.trg_pad_idx = trg_pad_idx
50 |
51 | self.apply(self.init_weights)
52 |
53 | def init_weights(self, m):
54 | if hasattr(m, 'weight') and m.weight.dim() > 1:
55 | nn.init.xavier_uniform_(m.weight.data)
56 |
57 | def make_src_mask(self, src):
58 | """
59 | make padding mask for src
60 | :param src: [B, src_len]
61 | :return: [B, 1, 1, src_len]
62 | """
63 | src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
64 | return src_mask
65 |
66 | def make_trg_mask(self, trg):
67 | """
68 | make padding and look-ahead mask for trg
69 | :param trg: [B, trg_len]
70 | :return: [B, 1, trg_len, trg_len]
71 | """
72 | trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2) # [batch size, 1, 1, trg len]
73 |
74 | trg_len = trg.shape[1]
75 | trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len))).bool() # [trg len, trg len]
76 |
77 | trg_mask = trg_pad_mask & trg_sub_mask.to(trg_pad_mask.device) # [batch size, 1, trg len, trg len]
78 |
79 | return trg_mask
80 |
81 | def forward(self, src, trg):
82 | """
83 | :param src: [B, src_len]
84 | :param trg: [B, trg_len]
85 |         :return: dict of decoder outputs (see Decoder.forward)
86 | """
87 | src_mask = self.make_src_mask(src) # [batch size, 1, 1, src len]
88 | trg_mask = self.make_trg_mask(trg) # [batch size, 1, trg len, trg len]
89 |
90 | # encoder
91 | enc_src = self.encoder(src, src_mask) # [batch size, src len, hid dim]
92 |
93 | # decoder
94 | output = self.decoder(trg, enc_src, trg_mask, src_mask)
95 | # [batch size, trg len, output dim], [batch size, n heads, trg len, src len]
96 | return output
97 |
98 |
99 | class Encoder(nn.Module):
100 | def __init__(self,
101 | input_dim,
102 | mc_dim,
103 | hid_dim,
104 | n_layers,
105 | n_heads,
106 | pf_dim,
107 | dropout,
108 | max_len,
109 | max_conn=10):
110 | super().__init__()
111 | self.mc_embedding = nn.Embedding(mc_dim, hid_dim)
112 | self.num_embedding = nn.Embedding(max_conn, hid_dim) # num_conn of ol
113 | self.vocab_embedding = nn.Embedding(input_dim, hid_dim)
114 | self.pos_embedding = nn.Embedding(max_len, hid_dim)
115 |
116 | self.layers = nn.ModuleList([EncoderLayer(hid_dim,
117 | n_heads,
118 | pf_dim,
119 | dropout,
120 | )
121 | for _ in range(n_layers)])
122 |
123 | self.dropout = nn.Dropout(dropout)
124 |
125 | self.scale = torch.sqrt(torch.FloatTensor([hid_dim]))
126 |
127 | def forward(self, src, src_mask):
128 | """
129 |         :param src: [B, src_len], laid out as [mc, num_conn, ol tokens...]
130 |         :param src_mask: [B, 1, 1, src_len]
132 | :return: [B, src_len, hid_dim]
133 | """
134 | batch_size = src.shape[0]
135 | src_len = src.shape[1]
136 |
137 | pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(src.device) # [batch size, src len]
138 |
139 | src = torch.concat(
140 | [
141 | self.mc_embedding(src[:, 0].unsqueeze(-1)),
142 | self.num_embedding(src[:, 1].unsqueeze(-1)),
143 | self.vocab_embedding(src[:, 2:]), # [batch size, src_len, hid dim]
144 | ],
145 | dim=1
146 | )
147 |
148 | src = self.dropout((src * self.scale.to(src.device)) + self.pos_embedding(pos))
149 | # [batch size, src len, hid dim]
150 |
151 | for layer in self.layers:
152 | src = layer(src, src_mask) # [batch size, src len, hid dim]
153 |
154 | return src
155 |
156 |
157 | class Decoder(nn.Module):
158 | def __init__(self,
159 | output_dim,
160 | topo_dim,
161 | mc_dim,
162 | hid_dim,
163 | n_layers,
164 | n_heads,
165 | pf_dim,
166 | dropout,
167 | max_len,
168 | ):
169 | super().__init__()
170 |
171 | self.topo_embedding = nn.Embedding(topo_dim, hid_dim)
172 | self.mc_embedding = nn.Embedding(mc_dim, hid_dim)
173 | self.vocab_embedding = nn.Embedding(output_dim, hid_dim)
174 | self.pos_embedding = nn.Embedding(max_len, hid_dim)
175 |
176 | self.layers = nn.ModuleList([DecoderLayer(hid_dim,
177 | n_heads,
178 | pf_dim,
179 | dropout,
180 | )
181 | for _ in range(n_layers)])
182 |
183 | self.fc_out_topo = nn.Linear(hid_dim, topo_dim)
184 | self.fc_out_mc = nn.Linear(hid_dim, mc_dim)
185 | self.fc_out_ol = nn.Linear(hid_dim, output_dim)
186 |
187 | self.dropout = nn.Dropout(dropout)
188 |
189 | self.scale = torch.sqrt(torch.FloatTensor([hid_dim]))
190 |
191 | def forward(self, trg, enc_src, trg_mask, src_mask):
192 | """
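193 |         Token embeddings are chosen by position in trg:
194 |         position 0 ([SOS]) -> vocab_embedding, position 1 -> topo_embedding,
195 |         position 2 -> mc_embedding, positions 3+ (ol) -> vocab_embedding
196 |         :param trg: [B, trg_len]
197 |         :param enc_src: [B, src_len, hid_dim]
198 |         :param trg_mask: [B, 1, trg_len, trg_len]
199 |         :param src_mask: [B, 1, 1, src_len]
200 |         :return: dict with "output_topo" [B, topo_dim], "output_mc" [B, mc_dim] (trg_len >= 2), "output_ol" [B, trg_len - 2, output_dim] (trg_len >= 3), "attention" [B, n_heads, trg_len, src_len]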
201 | """
202 |
203 | batch_size = trg.shape[0]
204 | trg_len = trg.shape[1]
205 |
206 | pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(trg.device) # [batch size, trg len]
207 |
208 | if trg_len == 1: # [SOS]
209 | tok_emb = self.vocab_embedding(trg[:, 0].unsqueeze(-1))
210 | elif trg_len == 2: # [SOS, topo]
211 | tok_emb = torch.concat(
212 | [
213 | self.vocab_embedding(trg[:, 0].unsqueeze(-1)),
214 | self.topo_embedding(trg[:, 1].unsqueeze(-1)),
215 | ],
216 | dim=1
217 | )
218 | elif trg_len == 3: # [SOS, topo, mc]
219 | tok_emb = torch.concat(
220 | [
221 | self.vocab_embedding(trg[:, 0].unsqueeze(-1)),
222 | self.topo_embedding(trg[:, 1].unsqueeze(-1)),
223 | self.mc_embedding(trg[:, 2].unsqueeze(-1)),
224 | ],
225 | dim=1
226 | )
227 | else: # [SOS, topo, mc, ol]
228 | tok_emb = torch.concat(
229 | [
230 | self.vocab_embedding(trg[:, 0].unsqueeze(-1)),
231 | self.topo_embedding(trg[:, 1].unsqueeze(-1)),
232 | self.mc_embedding(trg[:, 2].unsqueeze(-1)),
233 | self.vocab_embedding(trg[:, 3:])
234 | ],
235 | dim=1
236 | )
237 |
238 | # [batch size, trg len, hid dim]
239 | trg = self.dropout((tok_emb * self.scale.to(trg.device)) + self.pos_embedding(pos))
240 |
241 | for layer in self.layers:
242 | trg, attention = layer(trg, enc_src, trg_mask, src_mask)
243 | # [batch size, trg len, hid dim], [batch size, n heads, trg len, src len]
244 |
245 | output = {}
246 | if trg_len == 1:
247 | output.update(
248 | {
249 | "output_topo" : self.fc_out_topo(trg[:, 0]), # [batch size, topo_dim]
250 | }
251 | )
252 | elif trg_len == 2:
253 | output.update(
254 | {
255 | "output_topo" : self.fc_out_topo(trg[:, 0]), # [batch size, topo_dim],
256 | "output_mc" : self.fc_out_mc(trg[:, 1]), # [batch size, mc_dim]
257 | }
258 | )
259 | else:
260 | output.update(
261 | {
262 | "output_topo" : self.fc_out_topo(trg[:, 0]), # [batch size, topo_dim],
263 | "output_mc" : self.fc_out_mc(trg[:, 1]), # [batch size, mc_dim]
264 |                     "output_ol" : self.fc_out_ol(trg[:, 2:]) # [batch size, trg len - 2, output dim]
265 | }
266 | )
267 | output.update({"attention" : attention})
268 | return output
269 |
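The target mask built by `make_trg_mask` above combines a padding mask with a lower-triangular look-ahead mask. The following is a minimal standalone sketch, assuming PyTorch is available and assuming a padding index of 0; in the module itself the index comes from `trg_pad_idx`.

```python
import torch


def make_trg_mask(trg, pad_idx=0):
    """Return a [B, 1, trg_len, trg_len] mask; True marks positions that may be attended."""
    pad_mask = (trg != pad_idx).unsqueeze(1).unsqueeze(2)  # [B, 1, 1, trg_len]
    trg_len = trg.shape[1]
    # Lower-triangular matrix: query position i may only see key positions <= i.
    sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
    return pad_mask & sub_mask  # broadcasts to [B, 1, trg_len, trg_len]


# One sequence of three tokens followed by one padding token.
trg = torch.tensor([[5, 7, 9, 0]])
mask = make_trg_mask(trg)
```

In the resulting 4 x 4 matrix the padded key column is False for every query row, and each row otherwise exposes only the current and earlier positions.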
--------------------------------------------------------------------------------
/mofreinforce/libs/selfies/README.md:
--------------------------------------------------------------------------------
1 | This SELFIES directory is adapted from the "selfies" repository: https://github.com/aspuru-guzik-group/selfies.git
2 |
3 | The following modifications have been made to the original repository to recognize dummy atoms "*" in SMILES,
4 | which were not provided for in the official "selfies" repository:
5 |
6 | (1) In selfies.utils.smiles_utils.py, the following lines have been added to line 91:
7 | ```python
8 | elif smiles[i] == "*":
9 | token = SMILESToken(bond_idx, i, i + 1,
10 | SMILESTokenTypes.ATOM, smiles[i:i + 1])
11 | ```
12 |
13 | (2) The dummy atom "*" has been added to `ORGANIC_SUBSET` in `selfies.constants.py`:
14 | ```python
15 | ORGANIC_SUBSET = {"*", "B", "C", "N", "O", "S", "P", "F", "Cl", "Br", "I"}
16 | ```
17 |
18 | (3) "*" has been added to the regex rule in line 110 of `selfies.grammar_rules.py`:
19 | ```python
20 | r"([A-Z][a-z]?|\*)" # element symbol
21 | ```
22 |
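The effect of modification (3) can be checked with the standard `re` module. This is a minimal sketch that uses only the element-symbol fragment quoted above, not the full pattern from `grammar_rules.py`; the helper name `match_element` is hypothetical.

```python
import re

# Only the element-symbol fragment quoted in item (3); the full pattern in
# grammar_rules.py contains further groups that are not reproduced here.
ELEMENT_FRAGMENT = re.compile(r"([A-Z][a-z]?|\*)")


def match_element(text):
    """Return the matched element symbol, or None if the whole string does not match."""
    m = ELEMENT_FRAGMENT.fullmatch(text)
    return m.group(1) if m else None


results = {s: match_element(s) for s in ["C", "Cl", "*", "c", "**"]}
```

With the added alternative `\*`, a single dummy atom is accepted alongside ordinary element symbols, while lowercase symbols and repeated asterisks are still rejected by this fragment.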
--------------------------------------------------------------------------------
/mofreinforce/libs/selfies/__init__.py:
--------------------------------------------------------------------------------
1 | # This code is adapted from the "selfies" repository:
2 | # https://github.com/aspuru-guzik-group/selfies.git
3 |
4 | # The code in this file is licensed under the Apache License, Version 2.0:
5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE
6 |
7 | # Adapted by Hyunsoo Park, 2022
8 |
9 | #!/usr/bin/env python
10 |
11 | """
12 | SELFIES: a robust representation of semantically constrained graphs with an
13 | example application in chemistry.
14 |
15 | SELFIES (SELF-referencIng Embedded Strings) is a general-purpose,
16 | sequence-based, robust representation of semantically constrained graphs.
17 | It is based on a Chomsky type-2 grammar, augmented with two self-referencing
18 | functions. A main objective is to use SELFIES as direct input into machine
19 | learning models, in particular in generative models, for the generation of
20 | outputs with high validity.
21 |
22 | The code presented here is a concrete application of SELFIES in chemistry, for
23 | the robust representation of molecules.
24 |
25 | Typical usage example:
26 | import selfies as sf
27 |
28 | benzene = "C1=CC=CC=C1"
29 | benzene_selfies = sf.encoder(benzene)
30 | benzene_smiles = sf.decoder(benzene_selfies)
31 |
32 | For comments, bug reports or feature ideas, please send an email to
33 | mario.krenn@utoronto.ca and alan@aspuru.com.
34 | """
35 |
36 | __version__ = "2.1.0"
37 |
38 | __all__ = [
39 | "encoder",
40 | "decoder",
41 | "get_preset_constraints",
42 | "get_semantic_robust_alphabet",
43 | "get_semantic_constraints",
44 | "set_semantic_constraints",
45 | "len_selfies",
46 | "split_selfies",
47 | "get_alphabet_from_selfies",
48 | "selfies_to_encoding",
49 | "batch_selfies_to_flat_hot",
50 | "encoding_to_selfies",
51 | "batch_flat_hot_to_selfies",
52 | "EncoderError",
53 | "DecoderError"
54 | ]
55 |
56 | from .bond_constraints import (
57 | get_preset_constraints,
58 | get_semantic_constraints,
59 | get_semantic_robust_alphabet,
60 | set_semantic_constraints
61 | )
62 | from .decoder import decoder
63 | from .encoder import encoder
64 | from .exceptions import DecoderError, EncoderError
65 | from .utils.encoding_utils import (
66 | batch_flat_hot_to_selfies,
67 | batch_selfies_to_flat_hot,
68 | encoding_to_selfies,
69 | selfies_to_encoding
70 | )
71 | from .utils.selfies_utils import (
72 | get_alphabet_from_selfies,
73 | len_selfies,
74 | split_selfies
75 | )
76 |
--------------------------------------------------------------------------------
/mofreinforce/libs/selfies/bond_constraints.py:
--------------------------------------------------------------------------------
1 | # This code is adapted from the "selfies" repository:
2 | # https://github.com/aspuru-guzik-group/selfies.git
3 |
4 | # The code in this file is licensed under the Apache License, Version 2.0:
5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE
6 |
7 | # Adapted by Hyunsoo Park, 2022
8 |
9 | import functools
10 | from itertools import product
11 | from typing import Dict, Set, Union
12 |
13 | from .constants import ELEMENTS, INDEX_ALPHABET
14 |
15 | _DEFAULT_CONSTRAINTS = {
16 | "H": 1, "F": 1, "Cl": 1, "Br": 1, "I": 1,
17 | "B": 3, "B+1": 2, "B-1": 4,
18 | "O": 2, "O+1": 3, "O-1": 1,
19 | "N": 3, "N+1": 4, "N-1": 2,
20 | "C": 4, "C+1": 5, "C-1": 3,
21 | "P": 5, "P+1": 6, "P-1": 4,
22 | "S": 6, "S+1": 7, "S-1": 5,
23 | "?": 8
24 | }
25 |
26 | _PRESET_CONSTRAINTS = {
27 | "default": dict(_DEFAULT_CONSTRAINTS),
28 | "octet_rule": dict(_DEFAULT_CONSTRAINTS),
29 | "hypervalent": dict(_DEFAULT_CONSTRAINTS)
30 | }
31 | _PRESET_CONSTRAINTS["octet_rule"].update(
32 | {"S": 2, "S+1": 3, "S-1": 1, "P": 3, "P+1": 4, "P-1": 2}
33 | )
34 | _PRESET_CONSTRAINTS["hypervalent"].update(
35 | {"Cl": 7, "Br": 7, "I": 7, "N": 5}
36 | )
37 |
38 | _current_constraints = _PRESET_CONSTRAINTS["default"]
39 |
40 |
41 | def get_preset_constraints(name: str) -> Dict[str, int]:
42 | """Returns the preset semantic constraints with the given name.
43 |
44 | Besides the aforementioned default constraints, :mod:`selfies` offers
45 | other preset constraints for convenience; namely, constraints that
46 |     enforce the octet rule
47 |     and constraints that accommodate hypervalent
48 |     molecules.
49 |
50 | The differences between these constraints can be summarized as follows:
51 |
52 | .. table::
53 | :align: center
54 | :widths: auto
55 |
56 | +-----------------+-----------+---+---+-----+-----+---+-----+-----+
57 | | | Cl, Br, I | N | P | P+1 | P-1 | S | S+1 | S-1 |
58 | +-----------------+-----------+---+---+-----+-----+---+-----+-----+
59 | | ``default`` | 1 | 3 | 5 | 6 | 4 | 6 | 7 | 5 |
60 | +-----------------+-----------+---+---+-----+-----+---+-----+-----+
61 | | ``octet_rule`` | 1 | 3 | 3 | 4 | 2 | 2 | 3 | 1 |
62 | +-----------------+-----------+---+---+-----+-----+---+-----+-----+
63 | | ``hypervalent`` | 7 | 5 | 5 | 6 | 4 | 6 | 7 | 5 |
64 | +-----------------+-----------+---+---+-----+-----+---+-----+-----+
65 |
66 | :param name: the preset name: ``default`` or ``octet_rule`` or
67 | ``hypervalent``.
68 | :return: the preset constraints with the specified name, represented
69 | as a dictionary which maps atoms (the keys) to their bonding capacities
70 | (the values).
71 | """
72 |
73 | if name not in _PRESET_CONSTRAINTS:
74 | raise ValueError("unrecognized preset name '{}'".format(name))
75 | return dict(_PRESET_CONSTRAINTS[name])
76 |
77 |
78 | def get_semantic_constraints() -> Dict[str, int]:
79 | """Returns the semantic constraints that :mod:`selfies` is currently
80 | operating on.
81 |
82 | :return: the current semantic constraints, represented as a dictionary
83 | which maps atoms (the keys) to their bonding capacities (the values).
84 | """
85 |
86 | global _current_constraints
87 | return dict(_current_constraints)
88 |
89 |
90 | def set_semantic_constraints(
91 | bond_constraints: Union[str, Dict[str, int]] = "default"
92 | ) -> None:
93 | """Updates the semantic constraints that :mod:`selfies` operates on.
94 |
95 | If the input is a string, the new constraints are taken to be
96 | the preset named ``bond_constraints``
97 | (see :func:`selfies.get_preset_constraints`).
98 |
99 | Otherwise, the input is a dictionary representing the new constraints.
100 | This dictionary maps atoms (the keys) to non-negative bonding
101 | capacities (the values); the atoms are specified by strings
102 | of the form ``E`` or ``E+C`` or ``E-C``,
103 | where ``E`` is an element symbol and ``C`` is a positive integer.
104 | For example, one may have:
105 |
106 | * ``bond_constraints["I-1"] = 0``
107 | * ``bond_constraints["C"] = 4``
108 |
109 | This dictionary must also contain the special ``?`` key, which indicates
110 | the bond capacities of all atoms that are not explicitly listed
111 | in the dictionary.
112 |
113 | :param bond_constraints: the name of a preset, or a dictionary
114 | representing the new semantic constraints.
115 | :return: ``None``.
116 | """
117 |
118 | global _current_constraints
119 |
120 | if isinstance(bond_constraints, str):
121 | _current_constraints = get_preset_constraints(bond_constraints)
122 |
123 | elif isinstance(bond_constraints, dict):
124 |
125 | # error checking
126 | if "?" not in bond_constraints:
127 | raise ValueError("bond_constraints missing '?' as a key")
128 |
129 | for key, value in bond_constraints.items():
130 |
131 | # error checking for keys
132 | j = max(key.find("+"), key.find("-"))
133 | if key == "?":
134 | valid = True
135 | elif j == -1:
136 | valid = (key in ELEMENTS)
137 | else:
138 | valid = (key[:j] in ELEMENTS) and key[j + 1:].isnumeric()
139 | if not valid:
140 | err_msg = "invalid key '{}' in bond_constraints".format(key)
141 | raise ValueError(err_msg)
142 |
143 | # error checking for values
144 | if not (isinstance(value, int) and value >= 0):
145 | err_msg = "invalid value at " \
146 | "bond_constraints['{}'] = {}".format(key, value)
147 | raise ValueError(err_msg)
148 |
149 | _current_constraints = dict(bond_constraints)
150 |
151 | else:
152 | raise ValueError("bond_constraints must be a str or dict")
153 |
154 | # clear cache since we changed alphabet
155 | get_semantic_robust_alphabet.cache_clear()
156 | get_bonding_capacity.cache_clear()
157 |
158 |
159 | @functools.lru_cache()
160 | def get_semantic_robust_alphabet() -> Set[str]:
161 | """Returns a subset of all SELFIES symbols that are constrained
162 | by :mod:`selfies` under the current semantic constraints.
163 |
164 | :return: a subset of all SELFIES symbols that are semantically constrained.
165 | """
166 |
167 | alphabet_subset = set()
168 | bonds = {"": 1, "=": 2, "#": 3}
169 |
170 | # add atomic symbols
171 | for (a, c), (b, m) in product(_current_constraints.items(), bonds.items()):
172 | if (m > c) or (a == "?"):
173 | continue
174 | symbol = "[{}{}]".format(b, a)
175 | alphabet_subset.add(symbol)
176 |
177 | # add branch and ring symbols
178 | for i in range(1, 4):
179 | alphabet_subset.add("[Ring{}]".format(i))
180 | alphabet_subset.add("[=Ring{}]".format(i))
181 | alphabet_subset.add("[Branch{}]".format(i))
182 | alphabet_subset.add("[=Branch{}]".format(i))
183 | alphabet_subset.add("[#Branch{}]".format(i))
184 |
185 | alphabet_subset.update(INDEX_ALPHABET)
186 |
187 | return alphabet_subset
188 |
189 |
190 | @functools.lru_cache()
191 | def get_bonding_capacity(element: str, charge: int) -> int:
192 | """Returns the bonding capacity of a given atom, under the current
193 | semantic constraints.
194 |
195 | :param element: the element of the input atom.
196 | :param charge: the charge of the input atom.
197 | :return: the bonding capacity of the input atom.
198 | """
199 |
200 | key = element
201 | if charge != 0:
202 | key += "{:+}".format(charge)
203 |
204 | if key in _current_constraints:
205 | return _current_constraints[key]
206 | else:
207 | return _current_constraints["?"]
208 |
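The lookup in `get_bonding_capacity` above builds a key from the element symbol plus a signed charge and falls back to the `?` entry for anything not listed. The following is a minimal standalone sketch of that key construction, using a small subset of the default constraints; the names `CONSTRAINTS` and `bonding_capacity` are hypothetical, and the sketch omits the caching done by the module.

```python
# Subset of the default constraints, for illustration only.
CONSTRAINTS = {"C": 4, "N": 3, "N+1": 4, "O": 2, "O-1": 1, "?": 8}


def bonding_capacity(element, charge, constraints=CONSTRAINTS):
    """Look up the bonding capacity, falling back to the '?' wildcard."""
    key = element
    if charge != 0:
        key += "{:+}".format(charge)  # 1 -> "+1", -1 -> "-1"
    return constraints.get(key, constraints["?"])


capacities = [
    bonding_capacity("N", 0),   # neutral nitrogen
    bonding_capacity("N", 1),   # key "N+1"
    bonding_capacity("O", -1),  # key "O-1"
    bonding_capacity("Zn", 2),  # key "Zn+2" is not listed, so "?" applies
]
```

The `{:+}` format specifier always emits a sign, which is what makes the keys line up with entries such as `N+1` and `O-1` in the constraint dictionaries.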
--------------------------------------------------------------------------------
/mofreinforce/libs/selfies/compatibility.py:
--------------------------------------------------------------------------------
1 | # This code is adapted from the "selfies" repository:
2 | # https://github.com/aspuru-guzik-group/selfies.git
3 |
4 | # The code in this file is licensed under the Apache License, Version 2.0:
5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE
6 |
7 | # Adapted by Hyunsoo Park, 2022
8 |
9 | from .utils.smiles_utils import atom_to_smiles, smiles_to_atom
10 |
11 |
12 | def modernize_symbol(symbol):
13 | """Converts a SELFIES symbol from \
32 | Union[str, Tuple[str, List[Tuple[str, List[Tuple[int, str]]]]]]:
33 | """Translates a SELFIES string into its corresponding SMILES string.
34 |
35 | This translation is deterministic but depends on the current semantic
36 |     constraints. The output SMILES string is guaranteed to be syntactically
37 | correct and guaranteed to represent a molecule that obeys the
38 | semantic constraints.
39 |
40 | :param selfies: the SELFIES string to be translated.
41 | :param compatible: if ``True``, this function will accept SELFIES strings
42 |         containing deprecated symbols from previous releases. However, this
43 |         function may behave differently than in previous major releases,
44 |         and should not be treated as backward compatible.
45 | Defaults to ``False``.
46 | :param attribute: if ``True``, an attribution map connecting selfies
47 | tokens to smiles tokens is output.
48 | :return: a SMILES string derived from the input SELFIES string.
49 | :raises DecoderError: if the input SELFIES string is malformed.
50 |
51 | :Example:
52 |
53 | >>> import selfies as sf
54 | >>> sf.decoder('[C][=C][F]')
55 | 'C=CF'
56 | """
57 |
58 | if compatible:
59 | msg = "\nselfies.decoder() may behave differently than in previous " \
60 | "major releases. We recommend using SELFIES that are up to date."
61 | warnings.warn(msg, stacklevel=2)
62 |
63 | mol = MolecularGraph(attributable=attribute)
64 |
65 | rings = []
66 | attribution_index = 0
67 | for s in selfies.split("."):
68 | n = _derive_mol_from_symbols(
69 | symbol_iter=enumerate(_tokenize_selfies(s, compatible)),
70 | mol=mol,
71 | selfies=selfies,
72 | max_derive=float("inf"),
73 | init_state=0,
74 | root_atom=None,
75 | rings=rings,
76 | attribute_stack=[] if attribute else None,
77 | attribution_index=attribution_index
78 | )
79 | attribution_index += n
80 | _form_rings_bilocally(mol, rings)
81 | return mol_to_smiles(mol, attribute)
82 |
83 |
84 | def _tokenize_selfies(selfies, compatible):
85 | if isinstance(selfies, str):
86 | symbol_iter = split_selfies(selfies)
87 | elif isinstance(selfies, list):
88 | symbol_iter = selfies
89 | else:
90 | raise ValueError() # should not happen
91 |
92 | try:
93 | for symbol in symbol_iter:
94 | if symbol == "[nop]":
95 | continue
96 | if compatible:
97 | symbol = modernize_symbol(symbol)
98 | yield symbol
99 | except ValueError as err:
100 | raise DecoderError(str(err)) from None
101 |
102 |
103 | def _derive_mol_from_symbols(
104 | symbol_iter, mol, selfies, max_derive,
105 | init_state, root_atom, rings, attribute_stack, attribution_index
106 | ):
107 | n_derived = 0
108 | state = init_state
109 | prev_atom = root_atom
110 |
111 | while (state is not None) and (n_derived < max_derive):
112 |
113 | try: # retrieve next symbol
114 | index, symbol = next(symbol_iter)
115 | n_derived += 1
116 | except StopIteration:
117 | break
118 |
119 | # Case 1: Branch symbol (e.g. [Branch1])
120 | if "ch" == symbol[-4:-2]:
121 |
122 | output = process_branch_symbol(symbol)
123 | if output is None:
124 | _raise_decoder_error(selfies, symbol)
125 | btype, n = output
126 |
127 | if state <= 1:
128 | next_state = state
129 | else:
130 | binit_state, next_state = next_branch_state(btype, state)
131 |
132 | Q = _read_index_from_selfies(symbol_iter, n_symbols=n)
133 | n_derived += n + _derive_mol_from_symbols(
134 | symbol_iter, mol, selfies, (Q + 1),
135 | init_state=binit_state, root_atom=prev_atom, rings=rings,
136 | attribute_stack=attribute_stack +
137 | [Attribution(index + attribution_index, symbol)
138 | ] if attribute_stack is not None else None,
139 | attribution_index=attribution_index
140 | )
141 |
142 | # Case 2: Ring symbol (e.g. [Ring2])
143 | elif "ng" == symbol[-4:-2]:
144 |
145 | output = process_ring_symbol(symbol)
146 | if output is None:
147 | _raise_decoder_error(selfies, symbol)
148 | ring_type, n, stereo = output
149 |
150 | if state == 0:
151 | next_state = state
152 | else:
153 | ring_order, next_state = next_ring_state(ring_type, state)
154 | bond_info = (ring_order, stereo)
155 |
156 | Q = _read_index_from_selfies(symbol_iter, n_symbols=n)
157 | n_derived += n
158 | lidx = max(0, prev_atom.index - (Q + 1))
159 | rings.append((mol.get_atom(lidx), prev_atom, bond_info))
160 |
161 | # Case 3: [epsilon]
162 | elif "eps" in symbol:
163 | next_state = 0 if (state == 0) else None
164 |
165 | # Case 4: regular symbol (e.g. [N], [=C], [F])
166 | else:
167 |
168 | output = process_atom_symbol(symbol)
169 | if output is None:
170 | _raise_decoder_error(selfies, symbol)
171 | (bond_order, stereo), atom = output
172 | cap = atom.bonding_capacity
173 |
174 | bond_order, next_state = next_atom_state(bond_order, cap, state)
175 | if bond_order == 0:
176 | if state == 0:
177 | o = mol.add_atom(atom, True)
178 | mol.add_attribution(
179 | o, attribute_stack +
180 | [Attribution(index + attribution_index, symbol)]
181 | if attribute_stack is not None else None)
182 | else:
183 | o = mol.add_atom(atom)
184 | mol.add_attribution(
185 | o, attribute_stack +
186 | [Attribution(index + attribution_index, symbol)]
187 | if attribute_stack is not None else None)
188 | src, dst = prev_atom.index, atom.index
189 | o = mol.add_bond(src=src, dst=dst,
190 | order=bond_order, stereo=stereo)
191 | mol.add_attribution(
192 | o, attribute_stack +
193 | [Attribution(index + attribution_index, symbol)]
194 | if attribute_stack is not None else None)
195 | prev_atom = atom
196 |
197 | if next_state is None:
198 | break
199 | state = next_state
200 |
201 | while n_derived < max_derive: # consume remaining tokens
202 | try:
203 | next(symbol_iter)
204 | n_derived += 1
205 | except StopIteration:
206 | break
207 |
208 | return n_derived
209 |
210 |
211 | def _raise_decoder_error(selfies, invalid_symbol):
212 | err_msg = "invalid symbol '{}'\n\tSELFIES: {}".format(
213 | invalid_symbol, selfies
214 | )
215 | raise DecoderError(err_msg)
216 |
217 |
218 | def _read_index_from_selfies(symbol_iter, n_symbols):
219 | index_symbols = []
220 | for _ in range(n_symbols):
221 | try:
222 | index_symbols.append(next(symbol_iter)[-1])
223 | except StopIteration:
224 | index_symbols.append(None)
225 | return get_index_from_selfies(*index_symbols)
226 |
227 |
228 | def _form_rings_bilocally(mol, rings):
229 | rings_made = [0] * len(mol)
230 |
231 | for latom, ratom, bond_info in rings:
232 | lidx, ridx = latom.index, ratom.index
233 |
234 | if lidx == ridx: # ring to the same atom forbidden
235 | continue
236 |
237 | order, (lstereo, rstereo) = bond_info
238 | lfree = latom.bonding_capacity - mol.get_bond_count(lidx)
239 | rfree = ratom.bonding_capacity - mol.get_bond_count(ridx)
240 |
241 | if lfree <= 0 or rfree <= 0:
242 | continue # no room for ring bond
243 | order = min(order, lfree, rfree)
244 |
245 | if mol.has_bond(a=lidx, b=ridx):
246 | bond = mol.get_dirbond(src=lidx, dst=ridx)
247 | new_order = min(order + bond.order, 3)
248 | mol.update_bond_order(a=lidx, b=ridx, new_order=new_order)
249 |
250 | else:
251 | mol.add_ring_bond(
252 | a=lidx, a_stereo=lstereo, a_pos=rings_made[lidx],
253 | b=ridx, b_stereo=rstereo, b_pos=rings_made[ridx],
254 | order=order
255 | )
256 | rings_made[lidx] += 1
257 | rings_made[ridx] += 1
258 |
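The derivation state used by `_derive_mol_from_symbols` records how many bonds the previous atom can still form. As a rough illustration, the sketch below copies the arithmetic of `next_atom_state` (from grammar_rules.py) into a standalone function and traces it over the atoms of `[C][=C][F]`. The capacity table `CAPS` (C: 4, F: 1) and the helper `trace_chain` are assumptions made for this example only; the real decoder reads capacities from the active semantic constraints.

```python
from typing import List, Optional, Tuple


def next_atom_state(
        bond_order: int, bond_cap: int, state: int
) -> Tuple[int, Optional[int]]:
    # Same arithmetic as grammar_rules.next_atom_state.
    if state == 0:
        bond_order = 0  # first atom of a fragment: nothing to bond to
    bond_order = min(bond_order, state, bond_cap)
    bonds_left = bond_cap - bond_order
    return bond_order, (None if bonds_left == 0 else bonds_left)


# Hypothetical bonding capacities, for illustration only.
CAPS = {"C": 4, "F": 1}


def trace_chain(atoms: List[Tuple[int, str]]) -> List[Tuple[int, Optional[int]]]:
    """Trace (actual bond order, next state) for a linear chain.

    `atoms` holds (requested bond order, element) pairs.
    """
    state: Optional[int] = 0
    trace = []
    for requested, element in atoms:
        if state is None:  # previous atom is saturated, derivation stops
            break
        actual, state = next_atom_state(requested, CAPS[element], state)
        trace.append((actual, state))
    return trace


# [C][=C][F]: requested bond orders 1, 2, 1
steps = trace_chain([(1, "C"), (2, "C"), (1, "F")])
```

Under these assumptions the fluorine uses its only bond, so the state becomes `None` and any further symbols in the fragment would be ignored. A requested bond order is also silently lowered whenever the previous atom has fewer free bonds than requested.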
--------------------------------------------------------------------------------
/mofreinforce/libs/selfies/encoder.py:
--------------------------------------------------------------------------------
1 | # This code is adapted from the "selfies" repository:
2 | # https://github.com/aspuru-guzik-group/selfies.git
3 |
4 | # The code in this file is licensed under the Apache License, Version 2.0:
5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE
6 |
7 | # Adapted by Hyunsoo Park, 2022
8 |
9 | from .exceptions import EncoderError, SMILESParserError
10 | from .grammar_rules import get_selfies_from_index
11 | from .utils.smiles_utils import (
12 | atom_to_smiles,
13 | bond_to_smiles,
14 | smiles_to_mol
15 | )
16 |
17 | from .mol_graph import AttributionMap
18 |
19 |
20 | def encoder(smiles: str, strict: bool = True, attribute: bool = False) -> str:
21 | """Translates a SMILES string into its corresponding SELFIES string.
22 |
23 | This translation is deterministic and does not depend on the
24 | current semantic constraints. Additionally, it preserves the atom order
25 | of the input SMILES string; thus, one could generate randomized SELFIES
26 | strings by generating randomized SMILES strings, and then translating them.
27 |
28 | By nature of SELFIES, it is impossible to represent molecules that
29 | violate the current semantic constraints as SELFIES strings.
30 | Thus, we provide the ``strict`` flag to guard against such cases. If
31 | ``strict=True``, then this function will raise a
32 | :class:`selfies.EncoderError` if the input SMILES string represents
33 | a molecule that violates the semantic constraints. If
34 | ``strict=False``, then this function will not raise any error; however,
35 | calling :func:`selfies.decoder` on a SELFIES string generated this
36 | way will *not* be guaranteed to recover a SMILES string representing
37 | the original molecule.
38 |
39 | :param smiles: the SMILES string to be translated. It is recommended to
40 | use RDKit to check that the strings passed into this function
41 | are valid SMILES strings.
42 | :param strict: if ``True``, this function will check that the
43 | input SMILES string obeys the semantic constraints.
44 | Defaults to ``True``.
45 | :param attribute: if ``True``, also return an attribution list.
46 | :return: a SELFIES string translated from the input SMILES string if
47 | attribute is ``False``, otherwise a tuple is returned of
48 | SELFIES string and attribution list.
49 | :raises EncoderError: if the input SMILES string is invalid,
50 | cannot be kekulized, or violates the semantic constraints with
51 | ``strict=True``.
52 |
53 | :Example:
54 |
55 | >>> import selfies as sf
56 | >>> sf.encoder("C=CF")
57 | '[C][=C][F]'
58 |
59 | .. note:: This function does not currently support SMILES with:
60 |
61 | * The wildcard symbol ``*``.
62 | * The quadruple bond symbol ``$``.
63 | * Chirality specifications other than ``@`` and ``@@``.
64 | * Ring bonds across a dot symbol (e.g. ``c1cc([O-].[Na+])ccc1``) or
65 | ring bonds between atoms that are over 4000 atoms apart.
66 |
67 | Although SELFIES does not have aromatic symbols, this function
68 | *does* support aromatic SMILES strings by internally kekulizing them
69 | before translation.
70 | """
71 |
72 | try:
73 | mol = smiles_to_mol(smiles, attributable=attribute)
74 | except SMILESParserError as err:
75 | err_msg = "failed to parse input\n\tSMILES: {}".format(smiles)
76 | raise EncoderError(err_msg) from err
77 |
78 | if not mol.kekulize():
79 | err_msg = "kekulization failed\n\tSMILES: {}".format(smiles)
80 | raise EncoderError(err_msg)
81 |
82 | if strict:
83 | _check_bond_constraints(mol, smiles)
84 |
85 | # invert chirality of atoms where necessary,
86 | # such that they are restored when the SELFIES is decoded
87 | for atom in mol.get_atoms():
88 | if ((atom.chirality is not None)
89 | and mol.has_out_ring_bond(atom.index)
90 | and _should_invert_chirality(mol, atom)):
91 | atom.invert_chirality()
92 |
93 | fragments = []
94 | attribution_maps = []
95 | attribution_index = 0
96 | for root in mol.get_roots():
97 | derived = list(_fragment_to_selfies(
98 | mol, None, root, attribution_maps, attribution_index))
99 | attribution_index += len(derived)
100 | fragments.append("".join(derived))
101 | # trim attribution map of empty tokens
102 | attribution_maps = [a for a in attribution_maps if a.token]
103 | result = ".".join(fragments), attribution_maps
104 | return result if attribute else result[0]
105 |
106 |
107 | def _check_bond_constraints(mol, smiles):
108 | errors = []
109 |
110 | for atom in mol.get_atoms():
111 | bond_cap = atom.bonding_capacity
112 | bond_count = mol.get_bond_count(atom.index)
113 | if bond_count > bond_cap:
114 | errors.append((atom_to_smiles(atom), bond_count, bond_cap))
115 |
116 | if errors:
117 | err_msg = "input violates the currently-set semantic constraints\n" \
118 | "\tSMILES: {}\n" \
119 | "\tErrors:\n".format(smiles)
120 | for e in errors:
121 | err_msg += "\t[{:} with {} bond(s) - " \
122 | "a max. of {} bond(s) was specified]\n".format(*e)
123 | raise EncoderError(err_msg)
124 |
125 |
126 | def _should_invert_chirality(mol, atom):
127 | out_bonds = mol.get_out_dirbonds(atom.index)
128 |
129 | # 1. rings whose right number is bonded to this atom (e.g. ...1...X1)
130 | # 2. rings whose left number is bonded to this atom (e.g. X1...1...)
131 | # 3. branches and other (e.g. X(...)...)
132 | partition = [[], [], []]
133 | for i, bond in enumerate(out_bonds):
134 | if not bond.ring_bond:
135 | partition[2].append(i)
136 | elif bond.src < bond.dst:
137 | partition[1].append(i)
138 | else:
139 | partition[0].append(i)
140 | partition[1].sort(key=lambda x: out_bonds[x].dst)
141 |
142 | # construct permutation
143 | perm = partition[0] + partition[1] + partition[2]
144 | count = 0
145 | for i in range(len(perm)):
146 | for j in range(i + 1, len(perm)):
147 | if perm[i] > perm[j]:
148 | count += 1
149 | return count % 2 != 0 # if odd permutation, should invert chirality
150 |
151 |
152 | def _fragment_to_selfies(mol, bond_into_root, root,
153 | attribution_maps, attribution_index=0):
154 | derived = []
155 |
156 | bond_into_curr, curr = bond_into_root, root
157 | while True:
158 | curr_atom = mol.get_atom(curr)
159 | token = _atom_to_selfies(bond_into_curr, curr_atom)
160 | derived.append(token)
161 |
162 | attribution_maps.append(AttributionMap(
163 | len(derived) - 1 + attribution_index,
164 | token, mol.get_attribution(curr_atom)))
165 |
166 | out_bonds = mol.get_out_dirbonds(curr)
167 | for i, bond in enumerate(out_bonds):
168 |
169 | if bond.ring_bond:
170 | if bond.src < bond.dst:
171 | continue
172 |
173 | rev_bond = mol.get_dirbond(src=bond.dst, dst=bond.src)
174 | ring_len = bond.src - bond.dst
175 | Q_as_symbols = get_selfies_from_index(ring_len - 1)
176 | ring_symbol = "[{}Ring{}]".format(
177 | _ring_bonds_to_selfies(rev_bond, bond),
178 | len(Q_as_symbols)
179 | )
180 |
181 | derived.append(ring_symbol)
182 | attribution_maps.append(AttributionMap(
183 | len(derived) - 1 + attribution_index,
184 | ring_symbol, mol.get_attribution(bond)))
185 | for symbol in Q_as_symbols:
186 | derived.append(symbol)
187 | attribution_maps.append(AttributionMap(
188 | len(derived) - 1 + attribution_index,
189 | symbol, mol.get_attribution(bond)))
190 |
191 | elif i == len(out_bonds) - 1:
192 | bond_into_curr, curr = bond, bond.dst
193 |
194 | else:
195 | # record start and end so that the offsets of the
196 | # tokens inside the branch can be corrected after
197 | # the branch symbol is placed in front of them
198 | start = len(attribution_maps)
199 | branch = _fragment_to_selfies(
200 | mol, bond, bond.dst, attribution_maps, len(derived))
201 | Q_as_symbols = get_selfies_from_index(len(branch) - 1)
202 | branch_symbol = "[{}Branch{}]".format(
203 | _bond_to_selfies(bond, show_stereo=False),
204 | len(Q_as_symbols)
205 | )
206 | end = len(attribution_maps)
207 |
208 | derived.append(branch_symbol)
209 | for symbol in Q_as_symbols:
210 | derived.append(symbol)
211 | attribution_maps.append(AttributionMap(
212 | len(derived) - 1 + attribution_index,
213 | symbol, mol.get_attribution(bond)))
214 |
215 | # account for branch symbol because it is inserted after
216 | for j in range(start, end):
217 | attribution_maps[j].index += len(Q_as_symbols) + 1
218 | attribution_maps.append(AttributionMap(
219 | len(derived) - 1 + attribution_index,
220 | branch_symbol, mol.get_attribution(bond)))
221 |
222 | derived.extend(branch)
223 |
224 | # end of chain
225 | if (not out_bonds) or out_bonds[-1].ring_bond:
226 | break
227 | return derived
228 |
229 |
230 | def _bond_to_selfies(bond, show_stereo=True):
231 | if not show_stereo and (bond.order == 1):
232 | return ""
233 | return bond_to_smiles(bond)
234 |
235 |
236 | def _ring_bonds_to_selfies(lbond, rbond):
237 | assert lbond.order == rbond.order
238 |
239 | if (lbond.order != 1) or all(b.stereo is None for b in (lbond, rbond)):
240 | return _bond_to_selfies(lbond, show_stereo=False)
241 | else:
242 | bond_char = "-" if (lbond.stereo is None) else lbond.stereo
243 | bond_char += "-" if (rbond.stereo is None) else rbond.stereo
244 | return bond_char
245 |
246 |
247 | def _atom_to_selfies(bond, atom):
248 | assert not atom.is_aromatic
249 | bond_char = "" if (bond is None) else _bond_to_selfies(bond)
250 | return "[{}{}]".format(bond_char, atom_to_smiles(atom, brackets=False))
251 |
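`_should_invert_chirality` reduces to a parity check: it builds the order in which the outgoing bonds will appear after encoding and inverts the chirality tag when that reordering is an odd permutation. The sketch below isolates the inversion-counting step. `is_odd_permutation` is a hypothetical helper name and is not part of the library.

```python
from typing import Sequence


def is_odd_permutation(perm: Sequence[int]) -> bool:
    """Return True if `perm` contains an odd number of inversions."""
    inversions = sum(
        1
        for i in range(len(perm))
        for j in range(i + 1, len(perm))
        if perm[i] > perm[j]  # pair (i, j) is out of order
    )
    return inversions % 2 != 0


identity = is_odd_permutation([0, 1, 2])   # no inversions
one_swap = is_odd_permutation([1, 0, 2])   # one inversion
rotation = is_odd_permutation([2, 0, 1])   # two inversions
```

The quadratic loop is adequate here because an atom rarely has more than a handful of outgoing bonds.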
--------------------------------------------------------------------------------
/mofreinforce/libs/selfies/exceptions.py:
--------------------------------------------------------------------------------
1 | # This code is adapted from the "selfies" repository:
2 | # https://github.com/aspuru-guzik-group/selfies.git
3 |
4 | # The code in this file is licensed under the Apache License, Version 2.0:
5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE
6 |
7 | # Adapted by Hyunsoo Park, 2022
8 |
9 | class SMILESParserError(ValueError):
10 | """Exception raised when a SMILES fails to be parsed.
11 | """
12 |
13 | def __init__(self, smiles, reason="N/A", idx=-1):
14 | self.smiles = smiles
15 | self.idx = idx
16 | self.reason = reason
17 |
18 | def __str__(self):
19 | err_msg = "\n" \
20 | "\tSMILES: {smiles}\n" \
21 | "\t {pointer}\n" \
22 | "\tIndex: {index}\n" \
23 | "\tReason: {reason}"
24 |
25 | return err_msg.format(
26 | smiles=self.smiles,
27 | pointer=(" " * self.idx + "^"),
28 | index=self.idx,
29 | reason=self.reason
30 | )
31 |
32 |
33 | class EncoderError(Exception):
34 | """Exception raised by :func:`selfies.encoder`.
35 | """
36 |
37 | pass
38 |
39 |
40 | class DecoderError(Exception):
41 | """Exception raised by :func:`selfies.decoder`.
42 | """
43 |
44 | pass
45 |
--------------------------------------------------------------------------------
/mofreinforce/libs/selfies/grammar_rules.py:
--------------------------------------------------------------------------------
1 | # This code is adapted from the "selfies" repository:
2 | # https://github.com/aspuru-guzik-group/selfies.git
3 |
4 | # The code in this file is licensed under the Apache License, Version 2.0:
5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE
6 |
7 | # Adapted by Hyunsoo Park, 2022
8 |
9 | import functools
10 | import itertools
11 | import re
12 | from typing import Any, List, Optional, Tuple
13 |
14 | from .constants import (
15 | ELEMENTS,
16 | INDEX_ALPHABET,
17 | INDEX_CODE,
18 | ORGANIC_SUBSET
19 | )
20 | from .mol_graph import Atom
21 | from .utils.smiles_utils import smiles_to_bond
22 |
23 |
24 | def process_atom_symbol(symbol: str) -> Optional[Tuple[Any, Atom]]:
25 | try:
26 | output = _PROCESS_ATOM_CACHE[symbol]
27 | except KeyError:
28 | output = _process_atom_selfies_no_cache(symbol)
29 | if output is None:
30 | return None
31 | _PROCESS_ATOM_CACHE[symbol] = output
32 |
33 | bond_info, atom_fac = output
34 | atom = atom_fac()
35 | if atom.bonding_capacity < 0:
36 | return None # too many Hs (e.g. [CH9])
37 | return bond_info, atom
38 |
39 |
40 | def process_branch_symbol(symbol: str) -> Optional[Tuple[int, int]]:
41 | try:
42 | return _PROCESS_BRANCH_CACHE[symbol]
43 | except KeyError:
44 | return None
45 |
46 |
47 | def process_ring_symbol(symbol: str) -> Optional[Tuple[int, int, Any]]:
48 | try:
49 | return _PROCESS_RING_CACHE[symbol]
50 | except KeyError:
51 | return None
52 |
53 |
54 | def next_atom_state(
55 | bond_order: int, bond_cap: int, state: int
56 | ) -> Tuple[int, Optional[int]]:
57 | if state == 0:
58 | bond_order = 0
59 |
60 | bond_order = min(bond_order, state, bond_cap)
61 | bonds_left = bond_cap - bond_order
62 | next_state = None if (bonds_left == 0) else bonds_left
63 | return bond_order, next_state
64 |
65 |
66 | def next_branch_state(
67 | branch_type: int, state: int
68 | ) -> Tuple[int, Optional[int]]:
69 | assert 1 <= branch_type <= 3
70 | assert state > 1
71 |
72 | branch_init_state = min(state - 1, branch_type)
73 | next_state = state - branch_init_state
74 | return branch_init_state, next_state
75 |
76 |
77 | def next_ring_state(
78 | ring_type: int, state: int
79 | ) -> Tuple[int, Optional[int]]:
80 | assert state > 0
81 |
82 | bond_order = min(ring_type, state)
83 | bonds_left = state - bond_order
84 | next_state = None if (bonds_left == 0) else bonds_left
85 | return bond_order, next_state
86 |
87 |
88 | def get_index_from_selfies(*symbols: Optional[str]) -> int:
89 | index = 0
90 | for i, c in enumerate(reversed(symbols)):
91 | index += INDEX_CODE.get(c, 0) * (len(INDEX_CODE) ** i)
92 | return index
93 |
94 |
95 | def get_selfies_from_index(index: int) -> List[str]:
96 | if index < 0:
97 | raise IndexError()
98 | elif index == 0:
99 | return [INDEX_ALPHABET[0]]
100 |
101 | symbols = []
102 | base = len(INDEX_ALPHABET)
103 | while index:
104 | symbols.append(INDEX_ALPHABET[index % base])
105 | index //= base
106 | return symbols[::-1]
107 |
108 |
109 | # =============================================================================
110 | # Caches (for computational speed)
111 | # =============================================================================
112 |
113 |
114 | SELFIES_ATOM_PATTERN = re.compile(
115 | r"^[\[]" # opening square bracket [
116 | r"([=#/\\]?)" # bond char
117 | r"(\d*)" # isotope number (optional, e.g. 123, 26)
118 | r"([A-Z][a-z]?|\*)" # element symbol
119 | r"([@]{0,2})" # chiral_tag (optional, only @ and @@ supported)
120 | r"((?:[H]\d)?)" # H count (optional, e.g. H1, H3)
121 | r"((?:[+-][1-9]+)?)" # charge (optional, e.g. +1)
122 | r"[]]$" # closing square bracket ]
123 | )
124 |
125 |
126 | def _process_atom_selfies_no_cache(symbol):
127 | m = SELFIES_ATOM_PATTERN.match(symbol)
128 | if m is None:
129 | return None
130 | bond_char, isotope, element, chirality, h_count, charge = m.groups()
131 |
132 | if symbol[1 + len(bond_char):-1] in ORGANIC_SUBSET:
133 | atom_fac = functools.partial(Atom, element=element, is_aromatic=False)
134 | return smiles_to_bond(bond_char), atom_fac
135 |
136 | isotope = None if (isotope == "") else int(isotope)
137 | if element not in ELEMENTS:
138 | return None
139 | chirality = None if (chirality == "") else chirality
140 |
141 | s = h_count
142 | if s == "":
143 | h_count = 0
144 | else:
145 | h_count = int(s[1:])
146 |
147 | s = charge
148 | if s == "":
149 | charge = 0
150 | else:
151 | charge = int(s[1:])
152 | charge *= 1 if (s[0] == "+") else -1
153 |
154 | atom_fac = functools.partial(
155 | Atom,
156 | element=element,
157 | is_aromatic=False,
158 | isotope=isotope,
159 | chirality=chirality,
160 | h_count=h_count,
161 | charge=charge
162 | )
163 |
164 | return smiles_to_bond(bond_char), atom_fac
165 |
166 |
167 | def _build_atom_cache():
168 | cache = dict()
169 | common_symbols = [
170 | "[#C+1]", "[#C-1]", "[#C]", "[#N+1]", "[#N]", "[#O+1]", "[#P+1]",
171 | "[#P-1]", "[#P]", "[#S+1]", "[#S-1]", "[#S]", "[=C+1]", "[=C-1]",
172 | "[=C]", "[=N+1]", "[=N-1]", "[=N]", "[=O+1]", "[=O]", "[=P+1]",
173 | "[=P-1]", "[=P]", "[=S+1]", "[=S-1]", "[=S]", "[Br]", "[C+1]", "[C-1]",
174 | "[C]", "[Cl]", "[F]", "[H]", "[I]", "[N+1]", "[N-1]", "[N]", "[O+1]",
175 | "[O-1]", "[O]", "[P+1]", "[P-1]", "[P]", "[S+1]", "[S-1]", "[S]"
176 | ]
177 |
178 | for symbol in common_symbols:
179 | cache[symbol] = _process_atom_selfies_no_cache(symbol)
180 | return cache
181 |
182 |
183 | def _build_branch_cache():
184 | cache = dict()
185 | for L in range(1, 4):
186 | for bond_char in ["", "=", "#"]:
187 | symbol = "[{}Branch{}]".format(bond_char, L)
188 | cache[symbol] = (smiles_to_bond(bond_char)[0], L)
189 | return cache
190 |
191 |
192 | def _build_ring_cache():
193 | cache = dict()
194 | for L in range(1, 4):
195 | # [RingL], [=RingL], [#RingL]
196 | for bond_char in ["", "=", "#"]:
197 | symbol = "[{}Ring{}]".format(bond_char, L)
198 | order, stereo = smiles_to_bond(bond_char)
199 | cache[symbol] = (order, L, (stereo, stereo))
200 |
201 | # [-/RingL], [\/RingL], [\-RingL], ...
202 | for lchar, rchar in itertools.product(["-", "/", "\\"], repeat=2):
203 | if lchar == rchar == "-":
204 | continue
205 | symbol = "[{}{}Ring{}]".format(lchar, rchar, L)
206 | order, lstereo = smiles_to_bond(lchar)
207 | order, rstereo = smiles_to_bond(rchar)
208 | cache[symbol] = (order, L, (lstereo, rstereo))
209 | return cache
210 |
211 |
212 | _PROCESS_ATOM_CACHE = _build_atom_cache()
213 |
214 | _PROCESS_BRANCH_CACHE = _build_branch_cache()
215 |
216 | _PROCESS_RING_CACHE = _build_ring_cache()
217 |
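`get_selfies_from_index` and `get_index_from_selfies` treat the index alphabet as the digits of a positional number system whose base is the alphabet size. The sketch below reproduces that round trip with a made-up four-symbol alphabet. `ALPHABET`, `CODE`, `to_symbols` and `from_symbols` are assumptions for this example; the real `INDEX_ALPHABET` and `INDEX_CODE` live in constants.py and are not shown here, so the numbers are illustrative only.

```python
from typing import List, Optional

# Hypothetical alphabet: position in the list is the digit value.
ALPHABET = ["[C]", "[Ring1]", "[Ring2]", "[Branch1]"]
CODE = {symbol: digit for digit, symbol in enumerate(ALPHABET)}


def to_symbols(index: int) -> List[str]:
    """Write `index` in base len(ALPHABET), most significant digit first."""
    if index < 0:
        raise IndexError("index must be non-negative")
    if index == 0:
        return [ALPHABET[0]]
    base = len(ALPHABET)
    symbols = []
    while index:
        symbols.append(ALPHABET[index % base])
        index //= base
    return symbols[::-1]


def from_symbols(*symbols: Optional[str]) -> int:
    """Inverse of to_symbols; unknown or missing symbols count as digit 0."""
    base = len(CODE)
    index = 0
    for power, symbol in enumerate(reversed(symbols)):
        index += CODE.get(symbol, 0) * base ** power
    return index


six = to_symbols(6)            # 6 = 1 * 4 + 2
back = from_symbols(*six)
padded = from_symbols("[Ring1]", None)  # a missing trailing symbol acts as 0
```

Treating unknown symbols as zero is what lets the decoder keep going when an index symbol is missing at the end of a string, instead of raising.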
--------------------------------------------------------------------------------
/mofreinforce/libs/selfies/mol_graph.py:
--------------------------------------------------------------------------------
1 | # This code is adapted from the "selfies" repository:
2 | # https://github.com/aspuru-guzik-group/selfies.git
3 |
4 | # The code in this file is licensed under the Apache License, Version 2.0:
5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE
6 |
7 | # Adapted by Hyunsoo Park, 2022
8 |
9 | import functools
10 | import itertools
11 | from typing import List, Optional, Union
12 | from dataclasses import dataclass, field
13 |
14 | from .bond_constraints import get_bonding_capacity
15 | from .constants import AROMATIC_VALENCES
16 | from .utils.matching_utils import find_perfect_matching
17 |
18 |
19 | @dataclass
20 | class Attribution:
21 | """A dataclass that contains a token string and its index.
22 | """
23 | #: token index
24 | index: int
25 | #: token string
26 | token: str
27 |
28 |
29 | @dataclass
30 | class AttributionMap:
31 | """A mapping from the input to a single output token, showing which
32 | input tokens created the output token.
33 | """
34 | #: Index of output token
35 | index: int
36 | #: Output token
37 | token: str
38 | #: List of input tokens that created the output token
39 | attribution: List[Attribution] = field(default_factory=list)
40 |
41 |
42 | class Atom:
43 | """An atom with associated specifications (e.g. charge, chirality).
44 | """
45 |
46 | def __init__(
47 | self,
48 | element: str,
49 | is_aromatic: bool,
50 | isotope: Optional[int] = None,
51 | chirality: Optional[str] = None,
52 | h_count: Optional[int] = None,
53 | charge: int = 0
54 | ):
55 | self.index = None
56 | self.element = element
57 | self.is_aromatic = is_aromatic
58 | self.isotope = isotope
59 | self.chirality = chirality
60 | self.h_count = h_count
61 | self.charge = charge
62 |
63 | @property
64 | @functools.lru_cache()
65 | def bonding_capacity(self):
66 | bond_cap = get_bonding_capacity(self.element, self.charge)
67 | bond_cap -= 0 if (self.h_count is None) else self.h_count
68 | return bond_cap
69 |
70 | def invert_chirality(self) -> None:
71 | if self.chirality == "@":
72 | self.chirality = "@@"
73 | elif self.chirality == "@@":
74 | self.chirality = "@"
75 |
76 |
77 | class DirectedBond:
78 | """A bond that contains directional information.
79 | """
80 |
81 | def __init__(
82 | self,
83 | src: int,
84 | dst: int,
85 | order: Union[int, float],
86 | stereo: Optional[str],
87 | ring_bond: bool
88 | ):
89 | self.src = src
90 | self.dst = dst
91 | self.order = order
92 | self.stereo = stereo
93 | self.ring_bond = ring_bond
94 |
95 |
96 | class MolecularGraph:
97 | """A molecular graph.
98 |
99 | Molecules can be viewed as weighted undirected graphs. However, SMILES
100 | and SELFIES strings are more naturally represented as weighted directed
101 | graphs, where the direction of the edges specifies the order of atoms
102 | and bonds in the string.
103 | """
104 |
105 | def __init__(self, attributable=False):
106 | self._roots = list() # stores root atoms, where traversal begins
107 | self._atoms = list() # stores atoms in this graph
108 | self._bond_dict = dict() # stores all bonds in this graph
109 | self._adj_list = list() # adjacency list, representing this graph
110 | self._bond_counts = list() # stores number of bonds an atom has made
111 | self._ring_bond_flags = list() # stores if an atom makes a ring bond
112 | self._delocal_subgraph = dict() # delocalization subgraph
113 | self._attribution = dict() # attribution of each atom/bond
114 | self._attributable = attributable
115 |
116 | def __len__(self):
117 | return len(self._atoms)
118 |
119 | def has_bond(self, a: int, b: int) -> bool:
120 | if a > b:
121 | a, b = b, a
122 | return (a, b) in self._bond_dict
123 |
124 | def has_out_ring_bond(self, src: int) -> bool:
125 | return self._ring_bond_flags[src]
126 |
127 | def get_attribution(
128 | self,
129 | o: Union[DirectedBond, Atom]
130 | ) -> Optional[List[Attribution]]:
131 | if self._attributable and o in self._attribution:
132 | return self._attribution[o]
133 | return None
134 |
135 | def get_roots(self) -> List[int]:
136 | return self._roots
137 |
138 | def get_atom(self, idx: int) -> Atom:
139 | return self._atoms[idx]
140 |
141 | def get_atoms(self) -> List[Atom]:
142 | return self._atoms
143 |
144 | def get_dirbond(self, src, dst) -> DirectedBond:
145 | return self._bond_dict[(src, dst)]
146 |
147 | def get_out_dirbonds(self, src: int) -> List[DirectedBond]:
148 | return self._adj_list[src]
149 |
150 | def get_bond_count(self, idx: int) -> int:
151 | return self._bond_counts[idx]
152 |
153 | def add_atom(self, atom: Atom, mark_root: bool = False) -> Atom:
154 | atom.index = len(self)
155 |
156 | if mark_root:
157 | self._roots.append(atom.index)
158 | self._atoms.append(atom)
159 | self._adj_list.append(list())
160 | self._bond_counts.append(0)
161 | self._ring_bond_flags.append(False)
162 | if atom.is_aromatic:
163 | self._delocal_subgraph[atom.index] = list()
164 | return atom
165 |
166 | def add_attribution(
167 | self,
168 | o: Union[DirectedBond, Atom],
169 | attr: List[Attribution]
170 | ) -> None:
171 | if self._attributable:
172 | if o in self._attribution:
173 | self._attribution[o].extend(attr)
174 | else:
175 | self._attribution[o] = attr
176 |
177 | def add_bond(
178 | self, src: int, dst: int,
179 | order: Union[int, float], stereo: str
180 | ) -> DirectedBond:
181 | assert src < dst
182 |
183 | bond = DirectedBond(src, dst, order, stereo, False)
184 | self._add_bond_at_loc(bond, -1)
185 | self._bond_counts[src] += order
186 | self._bond_counts[dst] += order
187 |
188 | if order == 1.5:
189 | self._delocal_subgraph.setdefault(src, []).append(dst)
190 | self._delocal_subgraph.setdefault(dst, []).append(src)
191 | return bond
192 |
193 | def add_placeholder_bond(self, src: int) -> int:
194 | out_edges = self._adj_list[src]
195 | out_edges.append(None)
196 | return len(out_edges) - 1
197 |
198 | def add_ring_bond(
199 | self, a: int, b: int,
200 | order: Union[int, float],
201 | a_stereo: Optional[str], b_stereo: Optional[str],
202 | a_pos: int = -1, b_pos: int = -1
203 | ) -> None:
204 | a_bond = DirectedBond(a, b, order, a_stereo, True)
205 | b_bond = DirectedBond(b, a, order, b_stereo, True)
206 | self._add_bond_at_loc(a_bond, a_pos)
207 | self._add_bond_at_loc(b_bond, b_pos)
208 | self._bond_counts[a] += order
209 | self._bond_counts[b] += order
210 | self._ring_bond_flags[a] = True
211 | self._ring_bond_flags[b] = True
212 |
213 | if order == 1.5:
214 | self._delocal_subgraph.setdefault(a, []).append(b)
215 | self._delocal_subgraph.setdefault(b, []).append(a)
216 |
217 | def update_bond_order(
218 | self, a: int, b: int,
219 | new_order: Union[int, float]
220 | ) -> None:
221 | assert 1 <= new_order <= 3
222 |
223 | if a > b:
224 | a, b = b, a # swap so that a < b
225 | a_to_b = self._bond_dict[(a, b)] # prev step guarantees existence
226 | if new_order == a_to_b.order:
227 | return
228 | elif a_to_b.ring_bond:
229 | b_to_a = self._bond_dict[(b, a)]
230 | bonds = (a_to_b, b_to_a)
231 | else:
232 | bonds = (a_to_b,)
233 |
234 | old_order = bonds[0].order
235 | for bond in bonds:
236 | bond.order = new_order
237 | self._bond_counts[a] += (new_order - old_order)
238 | self._bond_counts[b] += (new_order - old_order)
239 |
240 | def _add_bond_at_loc(self, bond, pos):
241 | self._bond_dict[(bond.src, bond.dst)] = bond
242 |
243 | out_edges = self._adj_list[bond.src]
244 | if (pos == -1) or (pos == len(out_edges)):
245 | out_edges.append(bond)
246 | elif out_edges[pos] is None:
247 | out_edges[pos] = bond
248 | else:
249 | out_edges.insert(pos, bond)
250 |
251 | def is_kekulized(self) -> bool:
252 | return not self._delocal_subgraph
253 |
254 | def kekulize(self) -> bool:
255 | # Algorithm based on Depth-First article by Richard L. Apodaca
256 | # Reference:
257 | # https://depth-first.com/articles/2020/02/10/
258 | # a-comprehensive-treatment-of-aromaticity-in-the-smiles-language/
259 |
260 | if self.is_kekulized():
261 | return True
262 |
263 | ds = self._delocal_subgraph
264 | kept_nodes = set(itertools.filterfalse(self._prune_from_ds, ds))
265 |
266 | # relabel kept DS nodes to be 0, 1, 2, ...
267 | label_to_node = list(sorted(kept_nodes))
268 | node_to_label = {v: i for i, v in enumerate(label_to_node)}
269 |
270 | # pruned and relabelled DS
271 | pruned_ds = [list() for _ in range(len(kept_nodes))]
272 | for node in kept_nodes:
273 | label = node_to_label[node]
274 | for adj in filter(lambda v: v in kept_nodes, ds[node]):
275 | pruned_ds[label].append(node_to_label[adj])
276 |
277 | matching = find_perfect_matching(pruned_ds)
278 | if matching is None:
279 | return False
280 |
281 | # de-aromatize and then make double bonds
282 | for node in ds:
283 | for adj in ds[node]:
284 | self.update_bond_order(node, adj, new_order=1)
285 | self._atoms[node].is_aromatic = False
286 | self._bond_counts[node] = int(self._bond_counts[node])
287 |
288 | for matched_labels in enumerate(matching):
289 | matched_nodes = tuple(label_to_node[i] for i in matched_labels)
290 | self.update_bond_order(*matched_nodes, new_order=2)
291 |
292 | self._delocal_subgraph = dict() # clear DS
293 | return True
294 |
295 | def _prune_from_ds(self, node):
296 | adj_nodes = self._delocal_subgraph[node]
297 | if not adj_nodes:
298 | return True # aromatic atom with no aromatic bonds
299 |
300 | atom = self._atoms[node]
301 | valences = AROMATIC_VALENCES[atom.element]
302 |
303 | # each bond in DS has order 1.5 - we treat them as single bonds
304 | used_electrons = int(self._bond_counts[node] - 0.5 * len(adj_nodes))
305 |
306 | if atom.h_count is None: # account for implicit Hs
307 | assert atom.charge == 0
308 | return any(used_electrons == v for v in valences)
309 | else:
310 | valence = valences[-1] - atom.charge
311 | used_electrons += atom.h_count
312 | free_electrons = valence - used_electrons
313 | return not ((free_electrons >= 0) and (free_electrons % 2 != 0))
314 |
--------------------------------------------------------------------------------
/mofreinforce/libs/selfies/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/mofreinforce/libs/selfies/utils/encoding_utils.py:
--------------------------------------------------------------------------------
1 | # This code is adapted from the "selfies" repository:
2 | # https://github.com/aspuru-guzik-group/selfies.git
3 |
4 | # The code in this file is licensed under the Apache License, Version 2.0:
5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE
6 |
7 | # Adapted by Hyunsoo Park, 2022
8 |
9 | from typing import Dict, List, Tuple, Union
10 |
11 | from .selfies_utils import len_selfies, split_selfies
12 |
13 |
14 | def selfies_to_encoding(
15 | selfies: str,
16 | vocab_stoi: Dict[str, int],
17 | pad_to_len: int = -1,
18 | enc_type: str = 'both'
19 | ) -> Union[List[int], List[List[int]], Tuple[List[int], List[List[int]]]]:
20 | """Converts a SELFIES string into its label (integer)
21 | and/or one-hot encoding.
22 |
23 | A label encoded output will be a list of shape ``(L,)`` and a
24 | one-hot encoded output will be a 2D list of shape ``(L, len(vocab_stoi))``,
25 | where ``L`` is the symbol length of the SELFIES string. Optionally,
26 | the SELFIES string can be padded before it is encoded.
27 |
28 | :param selfies: the SELFIES string to be encoded.
29 | :param vocab_stoi: a dictionary that maps SELFIES symbols to indices,
30 | which must be non-negative and contiguous, starting from 0.
31 | If the SELFIES string is to be padded, then the special padding symbol
32 | ``[nop]`` must also be a key in this dictionary.
33 | :param pad_to_len: the length that the SELFIES string is padded to.
34 | If this value is less than or equal to the symbol length of the
35 | SELFIES string, then no padding is added. Defaults to ``-1``.
36 | :param enc_type: the type of encoding of the output:
37 | ``label`` or ``one_hot`` or ``both``.
38 | If this value is ``both``, then a tuple of the label and one-hot
39 | encodings is returned. Defaults to ``both``.
40 | :return: the label encoded and/or one-hot encoded SELFIES string.
41 |
42 | :Example:
43 |
44 | >>> import selfies as sf
45 | >>> sf.selfies_to_encoding("[C][F]", {"[C]": 0, "[F]": 1})
46 | ([0, 1], [[1, 0], [0, 1]])
47 | """
48 |
49 | # some error checking
50 | if enc_type not in ("label", "one_hot", "both"):
51 | raise ValueError("enc_type must be in ('label', 'one_hot', 'both')")
52 |
53 | # pad with [nop]
54 | if pad_to_len > len_selfies(selfies):
55 | selfies += "[nop]" * (pad_to_len - len_selfies(selfies))
56 |
57 | # integer encode
58 | char_list = split_selfies(selfies)
59 | integer_encoded = [vocab_stoi[char] for char in char_list]
60 |
61 | if enc_type == "label":
62 | return integer_encoded
63 |
64 | # one-hot encode
65 | one_hot_encoded = list()
66 | for index in integer_encoded:
67 | letter = [0] * len(vocab_stoi)
68 | letter[index] = 1
69 | one_hot_encoded.append(letter)
70 |
71 | if enc_type == "one_hot":
72 | return one_hot_encoded
73 | return integer_encoded, one_hot_encoded
74 |
75 |
76 | def encoding_to_selfies(
77 | encoding: Union[List[int], List[List[int]]],
78 | vocab_itos: Dict[int, str],
79 | enc_type: str,
80 | ) -> str:
81 | """Converts a label (integer) or one-hot encoding into a SELFIES string.
82 |
83 | If the input is label encoded, then a list of shape ``(L,)`` is
84 | expected; and if the input is one-hot encoded, then a 2D list of
85 | shape ``(L, len(vocab_itos))`` is expected.
86 |
87 | :param encoding: a label or one-hot encoding.
88 | :param vocab_itos: a dictionary that maps indices to SELFIES symbols.
89 | The indices of this dictionary must be non-negative and contiguous,
90 | starting from 0.
91 | :param enc_type: the type of encoding of the input:
92 | ``label`` or ``one_hot``.
93 | :return: the SELFIES string represented by the input encoding.
94 |
95 | :Example:
96 |
97 | >>> import selfies as sf
98 | >>> one_hot = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]
99 | >>> vocab_itos = {0: "[nop]", 1: "[C]", 2: "[F]"}
100 | >>> sf.encoding_to_selfies(one_hot, vocab_itos, enc_type="one_hot")
101 | '[C][F][nop]'
102 | """
103 |
104 | if enc_type not in ("label", "one_hot"):
105 | raise ValueError("enc_type must be in ('label', 'one_hot')")
106 |
107 | if enc_type == "one_hot": # Get integer encoding
108 | integer_encoded = []
109 | for row in encoding:
110 | integer_encoded.append(row.index(1))
111 | else:
112 | integer_encoded = encoding
113 |
114 | # Integer encoding -> SELFIES
115 | char_list = [vocab_itos[i] for i in integer_encoded]
116 | selfies = "".join(char_list)
117 |
118 | return selfies
119 |
120 |
121 | def batch_selfies_to_flat_hot(
122 | selfies_batch: List[str],
123 | vocab_stoi: Dict[str, int],
124 | pad_to_len: int = -1,
125 | ) -> List[List[int]]:
126 | """Converts a list of SELFIES strings into its list of flattened
127 | one-hot encodings.
128 |
129 | Each SELFIES string in the input list is one-hot encoded
130 | (and then flattened) using :func:`selfies.selfies_to_encoding`, with
131 | ``vocab_stoi`` and ``pad_to_len`` being passed in as arguments.
132 |
133 | :param selfies_batch: the list of SELFIES strings to be encoded.
134 | :param vocab_stoi: a dictionary that maps SELFIES symbols to indices.
135 | :param pad_to_len: the length that each SELFIES string in the input list
136 | is padded to. Defaults to ``-1``.
137 | :return: the flattened one-hot encodings of the input list.
138 |
139 | :Example:
140 |
141 | >>> import selfies as sf
142 | >>> batch = ["[C]", "[C][C]"]
143 | >>> vocab_stoi = {"[nop]": 0, "[C]": 1}
144 | >>> sf.batch_selfies_to_flat_hot(batch, vocab_stoi, 2)
145 | [[0, 1, 1, 0], [0, 1, 0, 1]]
146 | """
147 |
148 | hot_list = list()
149 |
150 | for selfies in selfies_batch:
151 | one_hot = selfies_to_encoding(selfies, vocab_stoi, pad_to_len,
152 | enc_type="one_hot")
153 | flattened = [elem for vec in one_hot for elem in vec]
154 | hot_list.append(flattened)
155 |
156 | return hot_list
157 |
158 |
159 | def batch_flat_hot_to_selfies(
160 | one_hot_batch: List[List[int]],
161 | vocab_itos: Dict[int, str],
162 | ) -> List[str]:
163 | """Converts a list of flattened one-hot encodings into a list
164 | of SELFIES strings.
165 |
166 | Each encoding in the input list is unflattened and then decoded using
167 | :func:`selfies.encoding_to_selfies`, with ``vocab_itos`` being passed in
168 | as an argument.
169 |
170 | :param one_hot_batch: a list of flattened one-hot encodings. Each
171 | encoding must be a list of length divisible by ``len(vocab_itos)``.
172 | :param vocab_itos: a dictionary that maps indices to SELFIES symbols.
173 | :return: the list of SELFIES strings represented by the input encodings.
174 |
175 | :Example:
176 |
177 | >>> import selfies as sf
178 | >>> batch = [[0, 1, 1, 0], [0, 1, 0, 1]]
179 | >>> vocab_itos = {0: "[nop]", 1: "[C]"}
180 | >>> sf.batch_flat_hot_to_selfies(batch, vocab_itos)
181 | ['[C][nop]', '[C][C]']
182 | """
183 |
184 | selfies_list = []
185 |
186 | for flat_one_hot in one_hot_batch:
187 |
188 | # Reshape to an L x M array where each column represents an alphabet
189 | # entry and each row is a position in the selfies
190 | one_hot = []
191 |
192 | M = len(vocab_itos)
193 | if len(flat_one_hot) % M != 0:
194 | raise ValueError("size of vector in one_hot_batch not divisible "
195 | "by the length of the vocabulary.")
196 | L = len(flat_one_hot) // M
197 |
198 | for i in range(L):
199 | one_hot.append(flat_one_hot[M * i: M * (i + 1)])
200 |
201 | selfies = encoding_to_selfies(one_hot, vocab_itos, enc_type="one_hot")
202 | selfies_list.append(selfies)
203 |
204 | return selfies_list
205 |
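The flattening and un-flattening done by the two batch helpers above can be sketched on its own. This is a minimal standalone illustration, not the library code: `one_hot_rows`, `flatten` and `unflatten` are hypothetical helper names, and the three-symbol vocabulary is made up for the example.

```python
from typing import Dict, List

def one_hot_rows(labels: List[int], vocab_size: int) -> List[List[int]]:
    """Turn integer labels into one-hot rows of width vocab_size."""
    rows = []
    for index in labels:
        row = [0] * vocab_size
        row[index] = 1
        rows.append(row)
    return rows

def flatten(rows: List[List[int]]) -> List[int]:
    """Concatenate the rows into a single flat vector."""
    return [elem for row in rows for elem in row]

def unflatten(flat: List[int], vocab_size: int) -> List[List[int]]:
    """Split a flat vector back into rows; its length must be a multiple of vocab_size."""
    if len(flat) % vocab_size != 0:
        raise ValueError("flat vector length not divisible by vocabulary size")
    return [flat[i:i + vocab_size] for i in range(0, len(flat), vocab_size)]

# Hypothetical vocabulary; index 0 is the padding symbol.
vocab_itos: Dict[int, str] = {0: "[nop]", 1: "[C]", 2: "[F]"}

labels = [1, 2, 0]  # "[C][F]" padded to length 3
flat = flatten(one_hot_rows(labels, len(vocab_itos)))
rows = unflatten(flat, len(vocab_itos))
decoded = "".join(vocab_itos[row.index(1)] for row in rows)
print(flat)
print(decoded)
```

A flat vector whose length is not a multiple of the vocabulary size is rejected here with a `ValueError`, matching the check in `batch_flat_hot_to_selfies`.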
--------------------------------------------------------------------------------
/mofreinforce/libs/selfies/utils/matching_utils.py:
--------------------------------------------------------------------------------
1 | # This code is adapted from the "selfies" repository:
2 | # https://github.com/aspuru-guzik-group/selfies.git
3 |
4 | # The code in this file is licensed under the Apache License, Version 2.0:
5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE
6 |
7 | # Adapted by Hyunsoo Park, 2022
8 |
9 | import heapq
10 | import itertools
11 | from collections import deque
12 | from typing import List, Optional
13 |
14 |
15 | def find_perfect_matching(graph: List[List[int]]) -> Optional[List[int]]:
16 | """Finds a perfect matching for an undirected graph (without self-loops).
17 |
18 | :param graph: an adjacency list representing the input graph.
19 | :return: a list representing a perfect matching, where j is the i-th
20 | element if nodes i and j are matched. Returns None, if the graph cannot
21 | be perfectly matched.
22 | """
23 |
24 | # start with a maximal matching for efficiency
25 | matching = _greedy_matching(graph)
26 |
27 | unmatched = set(i for i in range(len(graph)) if matching[i] is None)
28 | while unmatched:
29 |
30 | # find augmenting path which starts at root
31 | root = unmatched.pop()
32 | path = _find_augmenting_path(graph, root, matching)
33 |
34 | if path is None:
35 | return None
36 | else:
37 | _flip_augmenting_path(matching, path)
38 | unmatched.discard(path[0])
39 | unmatched.discard(path[-1])
40 |
41 | return matching
42 |
43 |
44 | def _greedy_matching(graph):
45 | matching = [None] * len(graph)
46 | free_degrees = [len(graph[i]) for i in range(len(graph))]
47 | # free_degrees[i] = number of unmatched neighbors for node i
48 |
49 | # prioritize nodes with fewer unmatched neighbors
50 | node_pqueue = [(free_degrees[i], i) for i in range(len(graph))]
51 | heapq.heapify(node_pqueue)
52 |
53 | while node_pqueue:
54 | _, node = heapq.heappop(node_pqueue)
55 |
56 | if (matching[node] is not None) or (free_degrees[node] == 0):
57 | continue # node cannot be matched
58 |
59 | # match node with first unmatched neighbor
60 | mate = next(i for i in graph[node] if matching[i] is None)
61 | matching[node] = mate
62 | matching[mate] = node
63 |
64 | for adj in itertools.chain(graph[node], graph[mate]):
65 | free_degrees[adj] -= 1
66 | if (matching[adj] is None) and (free_degrees[adj] > 0):
67 | heapq.heappush(node_pqueue, (free_degrees[adj], adj))
68 |
69 | return matching
70 |
71 |
72 | def _find_augmenting_path(graph, root, matching):
73 | assert matching[root] is None
74 |
75 | # run modified BFS to find path from root to unmatched node
76 | other_end = None
77 | node_queue = deque([root])
78 |
79 | # parent BFS tree - None indicates an unvisited node
80 | parents = [None] * len(graph)
81 | parents[root] = [None, None]
82 |
83 | while node_queue:
84 | node = node_queue.popleft()
85 |
86 | for adj in graph[node]:
87 | if matching[adj] is None: # unmatched node
88 | if adj != root: # augmenting path found!
89 | parents[adj] = [node, adj]
90 | other_end = adj
91 | break
92 | else:
93 | adj_mate = matching[adj]
94 | if parents[adj_mate] is None: # adj_mate not visited
95 | parents[adj_mate] = [node, adj]
96 | node_queue.append(adj_mate)
97 |
98 | if other_end is not None:
99 | break # augmenting path found!
100 |
101 | if other_end is None:
102 | return None
103 | else:
104 | path = []
105 | node = other_end
106 | while node != root:
107 | path.append(parents[node][1])
108 | path.append(parents[node][0])
109 | node = parents[node][0]
110 | return path
111 |
112 |
113 | def _flip_augmenting_path(matching, path):
114 | for i in range(0, len(path), 2):
115 | a, b = path[i], path[i + 1]
116 | matching[a] = b
117 | matching[b] = a
118 |
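To make the return format of `find_perfect_matching` concrete, here is a small standalone sketch that produces the same kind of output (a list where element `i` is the partner of node `i`, or `None` when no perfect matching exists). It uses plain backtracking rather than the greedy start and augmenting-path search above, so it is only suitable for small graphs; `perfect_matching_backtrack` is a hypothetical name.

```python
from typing import List, Optional

def perfect_matching_backtrack(graph: List[List[int]]) -> Optional[List[int]]:
    """Exhaustive search for a perfect matching on a small undirected graph."""
    n = len(graph)
    matching: List[Optional[int]] = [None] * n

    def solve() -> bool:
        # Pick the lowest-numbered node that is still unmatched.
        node = next((i for i in range(n) if matching[i] is None), None)
        if node is None:
            return True  # every node has a partner
        for adj in graph[node]:
            if adj != node and matching[adj] is None:
                matching[node], matching[adj] = adj, node
                if solve():
                    return True
                # Undo the tentative pairing and try the next neighbour.
                matching[node], matching[adj] = None, None
        return False

    return matching if solve() else None

# A six-membered ring (the delocalized subgraph of a benzene-like cycle).
hexagon = [[1, 5], [0, 2], [1, 3], [2, 4], [3, 5], [4, 0]]
# A three-node path has an odd number of nodes, so no perfect matching exists.
path3 = [[1], [0, 2], [1]]

ring_matching = perfect_matching_backtrack(hexagon)
path_matching = perfect_matching_backtrack(path3)
print(ring_matching)
print(path_matching)
```

In `kekulize`, each matched pair becomes a double bond, which is why a `None` result makes kekulization fail.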
--------------------------------------------------------------------------------
/mofreinforce/libs/selfies/utils/selfies_utils.py:
--------------------------------------------------------------------------------
1 | # This code is adapted from the "selfies" repository:
2 | # https://github.com/aspuru-guzik-group/selfies.git
3 |
4 | # The code in this file is licensed under the Apache License, Version 2.0:
5 | # https://github.com/aspuru-guzik-group/selfies/blob/master/LICENSE
6 |
7 | # Adapted by Hyunsoo Park, 2022
8 |
9 | from typing import Iterable, Iterator, Set
10 |
11 |
12 | def len_selfies(selfies: str) -> int:
13 | """Returns the number of symbols in a given SELFIES string.
14 |
15 | :param selfies: a SELFIES string.
16 | :return: the symbol length of the SELFIES string.
17 |
18 | :Example:
19 |
20 | >>> import selfies as sf
21 | >>> sf.len_selfies("[C][=C][F].[C]")
22 | 5
23 | """
24 |
25 | return selfies.count("[") + selfies.count(".")
26 |
27 |
28 | def split_selfies(selfies: str) -> Iterator[str]:
29 | """Tokenizes a SELFIES string into its individual symbols.
30 |
31 | :param selfies: a SELFIES string.
32 | :return: the symbols of the SELFIES string one-by-one with order preserved.
33 |
34 | :Example:
35 |
36 | >>> import selfies as sf
37 | >>> list(sf.split_selfies("[C][=C][F].[C]"))
38 | ['[C]', '[=C]', '[F]', '.', '[C]']
39 | """
40 |
41 | left_idx = selfies.find("[")
42 |
43 | while 0 <= left_idx < len(selfies):
44 | right_idx = selfies.find("]", left_idx + 1)
45 | if right_idx == -1:
46 | raise ValueError("malformed SELFIES string, hanging '[' bracket")
47 |
48 | next_symbol = selfies[left_idx: right_idx + 1]
49 | yield next_symbol
50 |
51 | left_idx = right_idx + 1
52 | if selfies[left_idx: left_idx + 1] == ".":
53 | yield "."
54 | left_idx += 1
55 |
56 |
57 | def get_alphabet_from_selfies(selfies_iter: Iterable[str]) -> Set[str]:
58 | """Constructs an alphabet from an iterable of SELFIES strings.
59 |
60 | The returned alphabet is the set of all symbols that appear in the
61 | SELFIES strings from the input iterable, minus the dot ``.`` symbol.
62 |
63 | :param selfies_iter: an iterable of SELFIES strings.
64 | :return: an alphabet of SELFIES symbols, built from the input iterable.
65 |
66 | :Example:
67 |
68 | >>> import selfies as sf
69 | >>> selfies_list = ["[C][F][O]", "[C].[O]", "[F][F]"]
70 | >>> alphabet = sf.get_alphabet_from_selfies(selfies_list)
71 | >>> sorted(list(alphabet))
72 | ['[C]', '[F]', '[O]']
73 | """
74 |
75 | alphabet = set()
76 | for s in selfies_iter:
77 | for symbol in split_selfies(s):
78 | alphabet.add(symbol)
79 | alphabet.discard(".")
80 | return alphabet
81 |
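The tokenizer and alphabet helper above are typically combined to build the `vocab_stoi` mapping that the encoding functions expect. The following is a hedged standalone sketch: it tokenizes with a regular expression instead of the generator in `split_selfies` (so it does not raise on a hanging `[`), and `tokenize` and `build_vocab` are hypothetical names. Putting `[nop]` at index 0 is an assumption made for the example.

```python
import re
from typing import Dict, Iterable, List

# A bracketed symbol, or the "." that separates disconnected fragments.
_SYMBOL = re.compile(r"\[[^\]]*\]|\.")

def tokenize(selfies: str) -> List[str]:
    """Split a SELFIES string into its symbols, keeping their order."""
    return _SYMBOL.findall(selfies)

def build_vocab(selfies_iter: Iterable[str], pad_symbol: str = "[nop]") -> Dict[str, int]:
    """Map each symbol to a contiguous index, with the padding symbol first."""
    alphabet = {tok for s in selfies_iter for tok in tokenize(s)}
    alphabet.discard(".")         # the dot is not part of the alphabet
    alphabet.discard(pad_symbol)  # avoid listing the pad symbol twice
    symbols = [pad_symbol] + sorted(alphabet)
    return {sym: idx for idx, sym in enumerate(symbols)}

tokens = tokenize("[C][=C][F].[C]")
vocab_stoi = build_vocab(["[C][F][O]", "[C].[O]", "[F][F]"])
print(tokens)
print(vocab_stoi)
```

Sorting the alphabet before assigning indices keeps the mapping reproducible across runs, since set iteration order is not guaranteed.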
--------------------------------------------------------------------------------
/mofreinforce/predictor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hspark1212/MOFreinforce/b920084a29db9223482ddfeffd2e775b82f49e63/mofreinforce/predictor/__init__.py
--------------------------------------------------------------------------------
/mofreinforce/predictor/baseline_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
4 |
5 |
6 | class BaselineModel(nn.Module):
7 | def __init__(self, vocab_dim, mc_dim, topo_dim, embed_dim, hid_dim):
8 | super(BaselineModel, self).__init__()
9 | # model
10 | # mc
11 | """
12 | self.embedding_mc = nn.Embedding(self.mc_dim, self.embed_dim)
13 | self.rc_mc = nn.Sequential(
14 | nn.Linear(self.embed_dim, 1),
15 | )
16 | """
17 | self.rc_mc = nn.Linear(mc_dim, 1)
18 | # topo
19 | self.embedding_topo = nn.Embedding(topo_dim, embed_dim)
20 | self.rc_topo = nn.Sequential(
21 | nn.Linear(embed_dim, 1),
22 | )
23 | # ol
24 | self.embedding_ol = nn.Embedding(vocab_dim, embed_dim)
25 | self.rnn = nn.RNN(input_size=embed_dim, hidden_size=hid_dim, num_layers=1)
26 | self.rc_ol = nn.Linear(hid_dim, 1)
27 | # total
28 | self.rc_total = nn.Linear(3, 1)
29 | self.sigmoid = nn.Sigmoid()
30 |
31 | def forward(self, batch):
32 | mc = batch["mc"]
33 | topo = batch["topo"]
34 | ol_pad = batch["ol_pad"]
35 | ol_len = batch["ol_len"]
36 | # mc
37 | # logit_mc = self.embedding_mc(mc) # [B, embed_dim]
38 | # logit_mc = self.rc_mc(logit_mc) # [B, 1]
39 | logit_mc = self.rc_mc(mc) # [B, 1]
40 |
41 | # topo
42 | logit_topo = self.embedding_topo(topo) # [B, embed_dim]
43 | logit_topo = self.rc_topo(logit_topo) # [B, 1]
44 | # ol
45 | logit_ol = self.embedding_ol(ol_pad) # [B, pad_len, embed_dim]
46 | packed_ol = pack_padded_sequence(
47 | logit_ol, ol_len, batch_first=True, enforce_sorted=False
48 | )
49 | output_packed, hidden = self.rnn(packed_ol) # hidden: [1, B, hid_dim]
50 | output_ol, len_ol = pad_packed_sequence(output_packed, batch_first=True) # [B, max(ol_len), hid_dim]
51 | logit_ol = self.rc_ol(hidden[-1]) # final state per sequence; output_ol[:, -1, :] is zero padding for shorter sequences
52 |
53 | # total
54 | logit_total = torch.cat([logit_topo, logit_mc, logit_ol], dim=-1)
55 | logit_total = self.rc_total(logit_total)
56 | logit_total = self.sigmoid(logit_total)
57 |
58 | return logit_total
59 |
--------------------------------------------------------------------------------
/mofreinforce/predictor/config_predictor.py:
--------------------------------------------------------------------------------
1 | import json
2 | from sacred import Experiment
3 |
4 | ex = Experiment("predictor")
5 |
6 | mc_to_idx = json.load(open("data/mc_to_idx.json"))
7 | topo_to_idx = json.load(open("data/topo_to_idx.json"))
8 | vocab_to_idx = json.load(open("data/vocab_to_idx.json")) # vocab for selfies
9 |
10 |
11 | def _loss_names(d):
12 | ret = {
13 | "regression": 0, # regression
14 | }
15 | ret.update(d)
16 | return ret
17 |
18 |
19 | @ex.config
20 | def config():
21 | seed = 0
22 | exp_name = "predictor"
23 |
24 | dataset_dir = "###"
25 | loss_names = _loss_names({"regression": 1})
26 |
27 | # model setting
28 | max_len = 128 # cls + mc + topo + ol_len
29 | vocab_dim = len(vocab_to_idx)
30 | mc_dim = len(mc_to_idx)
31 | topo_dim = len(topo_to_idx)
32 | weight_loss = None
33 |
34 | # transformer setting
35 | hid_dim = 256
36 | num_heads = 4 # hid_dim / 64
37 | num_layers = 4
38 | mlp_ratio = 4
39 | drop_rate = 0.1
40 |
41 | # run setting
42 | batch_size = 128
43 | per_gpu_batchsize = 64
44 | load_path = ""
45 | log_dir = "predictor/logs"
46 | num_workers = 8 # recommend num_gpus * 4
47 | num_nodes = 1
48 | devices = 2
49 | precision = 16
50 |
51 | # downstream
52 | downstream = ""
53 | n_classes = 0
54 | threshold_classification = None
55 |
56 | # Optimizer Setting
57 | optim_type = "adamw" # adamw, adam, sgd (momentum=0.9)
58 | learning_rate = 1e-4
59 | weight_decay = 1e-2
60 | decay_power = (
61 | 1 # default polynomial decay, [cosine, constant, constant_with_warmup]
62 | )
63 | max_epochs = 50
64 | max_steps = -1 # -1: derived as num_data * max_epochs // batch_size (accumulate_grad_batches)
65 | warmup_steps = 0.05 # int (step count) or float (fraction of max_steps)
66 | end_lr = 0
67 | lr_mult = 1 # multiply lr for downstream heads
68 |
69 | # trainer setting
70 | resume_from = None
71 | val_check_interval = 1.0
72 | test_only = False
73 |
74 | # normalize (when regression)
75 | mean = None
76 | std = None
77 |
78 |
79 | @ex.named_config
80 | def env_ifactor():
81 | pass
82 |
83 |
84 | @ex.named_config
85 | def regression_qkh_round1():
86 | exp_name = "regression_qkh_round1"
87 | dataset_dir = "data/dataset_predictor/qkh/round1"
88 |
89 | # trainer
90 | max_epochs = 50
91 | batch_size = 64
92 | per_gpu_batchsize = 16
93 |
94 | # normalize (when regression)
95 | mean = -20.331
96 | std = -10.383 # NOTE: negative std flips the sign of normalized targets; other rounds use a positive value
97 |
98 |
99 | @ex.named_config
100 | def regression_qkh_round2():
101 | exp_name = "regression_qkh_round2"
102 | dataset_dir = "data/dataset_predictor/qkh/round2"
103 |
104 | # trainer
105 | max_epochs = 50
106 | batch_size = 64
107 | per_gpu_batchsize = 2
108 |
109 | # normalize (when regression)
110 | mean = -21.068
111 | std = 10.950
112 |
113 |
114 | @ex.named_config
115 | def regression_qkh_round3():
116 | exp_name = "regression_qkh_round3"
117 | dataset_dir = "data/dataset_predictor/qkh/round3"
118 |
119 | # trainer
120 | max_epochs = 50
121 | batch_size = 64
122 | per_gpu_batchsize = 2
123 |
124 | # normalize (when regression)
125 | mean = -21.810
126 | std = 11.452
127 |
128 |
129 | """
130 | v1_selectivity
131 | """
132 |
133 |
134 | @ex.named_config
135 | def regression_selectivity_round1():
136 | exp_name = "regression_selectivity_round1"
137 | dataset_dir = "data/dataset_predictor/selectivity/round1"
138 |
139 | # trainer
140 | max_epochs = 50
141 | batch_size = 128
142 | per_gpu_batchsize = 16
143 |
144 | # normalize (when regression)
145 | mean = 1.872
146 | std = 1.922
147 |
148 |
149 | @ex.named_config
150 | def regression_selectivity_round2():
151 | exp_name = "regression_selectivity_round2"
152 | dataset_dir = "data/dataset_predictor/selectivity/round2"
153 |
154 | # trainer
155 | max_epochs = 50
156 | batch_size = 128
157 | per_gpu_batchsize = 16
158 |
159 | # normalize (when regression)
160 | mean = 2.085
161 | std = 2.052
162 |
163 |
164 | @ex.named_config
165 | def regression_selectivity_round3():
166 | exp_name = "regression_selectivity_round3"
167 | dataset_dir = "data/dataset_predictor/selectivity/round3"
168 |
169 | # trainer
170 | max_epochs = 50
171 | batch_size = 128
172 | per_gpu_batchsize = 16
173 |
174 | # normalize (when regression)
175 | mean = 2.258
176 | std = 2.128
177 |
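The `mean` and `std` values in the named configs are labelled as normalization constants for regression. How they are consumed is not shown in this file, so the following is a sketch under the assumption that targets are standardized as `(x - mean) / std` and predictions are mapped back with the inverse; `Normalizer` is a hypothetical class.

```python
from dataclasses import dataclass

@dataclass
class Normalizer:
    """Standardize regression targets with a fixed mean and std (assumed usage)."""
    mean: float
    std: float

    def encode(self, x: float) -> float:
        return (x - self.mean) / self.std

    def decode(self, z: float) -> float:
        return z * self.std + self.mean

# Values taken from the regression_qkh_round2 config.
round2 = Normalizer(mean=-21.068, std=10.950)
z_round2 = round2.encode(-10.118)  # one std above the mean

# Values taken from regression_qkh_round1, where std is negative.
round1 = Normalizer(mean=-20.331, std=-10.383)
z_round1 = round1.encode(-9.948)  # the same kind of offset now encodes with the opposite sign

print(z_round2, z_round1)
print(round1.decode(z_round1))
```

Encoding followed by decoding returns the original value for either sign of `std`, so a negative value stays self-consistent as long as training and inference use the same constants. It only changes the sign of the normalized targets.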
--------------------------------------------------------------------------------
/mofreinforce/predictor/config_predictor_ex.py:
--------------------------------------------------------------------------------
1 | import json
2 | from sacred import Experiment
3 |
4 | ex = Experiment("predictor")
5 |
6 | mc_to_idx = json.load(open("data/mc_to_idx.json"))
7 | topo_to_idx = json.load(open("data/topo_to_idx.json"))
8 | vocab_to_idx = json.load(open("data/vocab_to_idx.json")) # vocab for selfies
9 |
10 |
11 | def _loss_names(d):
12 | ret = {
13 | "regression": 0, # regression
14 | }
15 | ret.update(d)
16 | return ret
17 |
18 |
19 | @ex.config
20 | def config():
21 | seed = 0
22 | exp_name = "predictor"
23 |
24 | dataset_dir = "###"
25 | loss_names = _loss_names({"regression": 1})
26 |
27 | # model setting
28 | max_len = 128 # cls + mc + topo + ol_len
29 | vocab_dim = len(vocab_to_idx)
30 | mc_dim = len(mc_to_idx)
31 | topo_dim = len(topo_to_idx)
32 | weight_loss = None
33 |
34 | # transformer setting
35 | hid_dim = 256
36 | num_heads = 4 # hid_dim / 64
37 | num_layers = 4
38 | mlp_ratio = 4
39 | drop_rate = 0.1
40 |
41 | # run setting
42 | batch_size = 128
43 | per_gpu_batchsize = 64
44 | load_path = ""
45 | log_dir = "predictor/logs"
46 | num_workers = 8 # recommend num_gpus * 4
47 | num_nodes = 1
48 | devices = 2
49 | precision = 16
50 |
51 | # downstream
52 | downstream = ""
53 | n_classes = 0
54 | threshold_classification = None
55 |
56 | # Optimizer Setting
57 | optim_type = "adamw" # adamw, adam, sgd (momentum=0.9)
58 | learning_rate = 1e-4
59 | weight_decay = 1e-2
60 | decay_power = (
61 | 1 # default polynomial decay, [cosine, constant, constant_with_warmup]
62 | )
63 | max_epochs = 50
64 | max_steps = -1 # -1: derived as num_data * max_epochs // batch_size (accumulate_grad_batches)
65 | warmup_steps = 0.05 # int (step count) or float (fraction of max_steps)
66 | end_lr = 0
67 | lr_mult = 1 # multiply lr for downstream heads
68 |
69 | # trainer setting
70 | resume_from = None
71 | val_check_interval = 1.0
72 | test_only = False
73 |
74 | # normalize (when regression)
75 | mean = None
76 | std = None
77 |
78 |
79 | @ex.named_config
80 | def env_ifactor():
81 | pass
82 |
83 |
84 | @ex.named_config
85 | def regression_vf():
86 | exp_name = "regression_vf"
87 | dataset_dir = "data/dataset_predictor/vf"
88 |
89 | # trainer
90 | max_epochs = 50
91 | batch_size = 128
92 | per_gpu_batchsize = 16
93 |
94 |
95 | @ex.named_config
96 | def regression_qkh_old_round1():
97 | exp_name = "regression_qkh_old_round1"
98 | dataset_dir = "data/dataset_predictor/qkh/old_round1"
99 |
100 | # trainer
101 | max_epochs = 50
102 | batch_size = 64
103 | per_gpu_batchsize = 16
104 |
105 | # normalize (when regression)
106 | mean = -19.408
107 | std = -9.172 # NOTE: negative std flips the sign of normalized targets; other rounds use a positive value
108 |
109 |
110 | @ex.named_config
111 | def regression_qkh_old_round2():
112 | exp_name = "regression_qkh_old_round2"
113 | dataset_dir = "data/dataset_predictor/qkh/old_round2"
114 |
115 | # trainer
116 | max_epochs = 50
117 | batch_size = 64
118 | per_gpu_batchsize = 16
119 |
120 | # normalize (when regression)
121 | mean = -19.886
122 | std = -9.811 # NOTE: negative std flips the sign of normalized targets; other rounds use a positive value
123 |
124 |
125 | @ex.named_config
126 | def regression_qkh_new_round1():
127 | exp_name = "regression_qkh_new_round1"
128 | dataset_dir = "data/dataset_predictor/qkh/new_round1"
129 |
130 | # trainer
131 | max_epochs = 50
132 | batch_size = 64
133 | per_gpu_batchsize = 16
134 |
135 | # normalize (when regression)
136 | mean = -20.331
137 | std = -10.383 # NOTE: negative std flips the sign of normalized targets; other rounds use a positive value
138 |
139 |
140 | @ex.named_config
141 | def regression_qkh_new_round2():
142 | exp_name = "regression_qkh_new_round2"
143 | dataset_dir = "data/dataset_predictor/qkh/new_round2"
144 |
145 | # trainer
146 | max_epochs = 50
147 | batch_size = 64
148 | per_gpu_batchsize = 2
149 |
150 | # normalize (when regression)
151 | mean = -21.068
152 | std = 10.950
153 |
154 |
155 | @ex.named_config
156 | def regression_qkh_new_round3():
157 | exp_name = "regression_qkh_new_round3"
158 | dataset_dir = "data/dataset_predictor/qkh/new_round3"
159 |
160 | # trainer
161 | max_epochs = 50
162 | batch_size = 64
163 | per_gpu_batchsize = 2
164 |
165 | # normalize (when regression)
166 | mean = -21.810
167 | std = 11.452
168 |
169 |
170 | """
171 | v1_selectivity
172 | """
173 |
174 |
175 | @ex.named_config
176 | def regression_selectivity_new_round1():
177 | exp_name = "regression_selectivity_new_round1"
178 | dataset_dir = "data/dataset_predictor/selectivity/new_round1"
179 |
180 | # trainer
181 | max_epochs = 50
182 | batch_size = 128
183 | per_gpu_batchsize = 16
184 |
185 | # normalize (when regression)
186 | mean = 1.872
187 | std = 1.922
188 |
189 |
190 | @ex.named_config
191 | def regression_selectivity_new_round2():
192 | exp_name = "regression_selectivity_new_round2"
193 | dataset_dir = "data/dataset_predictor/selectivity/new_round2"
194 |
195 | # trainer
196 | max_epochs = 50
197 | batch_size = 128
198 | per_gpu_batchsize = 16
199 |
200 | # normalize (when regression)
201 | mean = 2.085
202 | std = 2.052
203 |
204 |
205 | @ex.named_config
206 | def regression_selectivity_new_round3():
207 | exp_name = "regression_selectivity_new_round3"
208 | dataset_dir = "data/dataset_predictor/selectivity/new_round3"
209 |
210 | # trainer
211 | max_epochs = 50
212 | batch_size = 128
213 | per_gpu_batchsize = 16
214 |
215 | # normalize (when regression)
216 | mean = 2.258
217 | std = 2.128
218 |
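The comments on `max_steps` and `warmup_steps` in the configs describe a small piece of arithmetic that can be sketched directly. The gradient-accumulation formula below is an assumption based on the `batch_size`, `per_gpu_batchsize`, `devices` and `num_nodes` settings, since the code that consumes them is not shown here. `resolve_schedule` and `grad_accumulation` are hypothetical helpers, and the dataset size of 10,000 is made up.

```python
from typing import Tuple, Union

def resolve_schedule(
    num_data: int,
    max_epochs: int,
    batch_size: int,
    warmup_steps: Union[int, float],
) -> Tuple[int, int]:
    """Derive max_steps and turn a fractional warmup into a step count."""
    max_steps = num_data * max_epochs // batch_size
    if isinstance(warmup_steps, float):
        warmup_steps = int(max_steps * warmup_steps)
    return max_steps, warmup_steps

def grad_accumulation(batch_size: int, per_gpu_batchsize: int, devices: int, num_nodes: int) -> int:
    """Assumed rule: accumulate until the effective batch reaches batch_size."""
    return max(1, batch_size // (per_gpu_batchsize * devices * num_nodes))

# Default config values with a hypothetical dataset of 10,000 entries.
max_steps, warmup = resolve_schedule(num_data=10_000, max_epochs=50, batch_size=128, warmup_steps=0.05)
accum_default = grad_accumulation(batch_size=128, per_gpu_batchsize=64, devices=2, num_nodes=1)
# new_round2 settings: a per-GPU batch of 2 needs many accumulation steps.
accum_round2 = grad_accumulation(batch_size=64, per_gpu_batchsize=2, devices=2, num_nodes=1)
print(max_steps, warmup, accum_default, accum_round2)
```

An integer `warmup_steps` passes through unchanged, which is how the "int or float" comment is read here.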
--------------------------------------------------------------------------------
/mofreinforce/predictor/datamodule.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from functools import partial
3 |
4 | from pytorch_lightning import LightningDataModule
5 | from torch.utils.data import DataLoader
6 | from predictor.dataset import MOFDataset
7 |
8 |
9 | class Datamodule(LightningDataModule):
10 | def __init__(self, _config):
11 | super().__init__()
12 | self.dataset_dir = _config["dataset_dir"]
13 | self.batch_size = _config["batch_size"]
14 | self.num_workers = _config["num_workers"]
15 | self.max_len = _config["max_len"]
16 | self.tasks = [k for k, v in _config["loss_names"].items() if v >= 1]
17 |
18 | @property
19 | def dataset_cls(self):
20 | return MOFDataset
21 |
22 | def set_train_dataset(self):
23 | self.train_dataset = self.dataset_cls(
24 | dataset_dir=self.dataset_dir,
25 | split="train",
26 | )
27 |
28 | def set_val_dataset(self):
29 | self.val_dataset = self.dataset_cls(
30 | dataset_dir=self.dataset_dir,
31 | split="val",
32 | )
33 |
34 | def set_test_dataset(self):
35 | self.test_dataset = self.dataset_cls(
36 | dataset_dir=self.dataset_dir,
37 | split="test",
38 | )
39 |
40 | def setup(self, stage: Optional[str] = None):
41 | if stage in (None, "fit"):
42 | self.set_train_dataset()
43 | self.set_val_dataset()
44 |
45 | if stage in (None, "test"):
46 | self.set_test_dataset()
47 |
48 | self.collate = partial(self.dataset_cls.collate, max_len=self.max_len)
49 |
50 | def train_dataloader(self):
51 | return DataLoader(
52 | self.train_dataset,
53 | batch_size=self.batch_size,
54 | num_workers=self.num_workers,
55 | collate_fn=self.collate,
56 | drop_last=True,
57 | pin_memory=True,
58 | )
59 |
60 | def val_dataloader(self):
61 | return DataLoader(
62 | self.val_dataset,
63 | batch_size=self.batch_size,
64 | num_workers=self.num_workers,
65 | collate_fn=self.collate,
66 | drop_last=True,
67 | pin_memory=True,
68 | )
69 |
70 | def test_dataloader(self):
71 | return DataLoader(
72 | self.test_dataset,
73 | batch_size=self.batch_size,
74 | num_workers=self.num_workers,
75 | collate_fn=self.collate,
76 | drop_last=True,
77 | pin_memory=True,
78 | )
79 |
--------------------------------------------------------------------------------
/mofreinforce/predictor/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | from pathlib import Path
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class MOFDataset(Dataset):
8 | def __init__(
9 | self,
10 | dataset_dir,
11 | split,
12 | ):
13 | assert split in ["train", "test", "val"]
14 |
15 | # load dict_mof
16 | path_dict_mof = Path(dataset_dir, f"{split}.json")
17 | print(f"read file : {path_dict_mof}")
18 | self.dict_mof = json.load(open(path_dict_mof, "r"))
19 | self.mof_name = list(self.dict_mof.keys())
20 |
21 | def __len__(self):
22 | return len(self.mof_name)
23 |
24 | def __getitem__(self, idx):
25 | ret = dict()
26 | mof_name = self.mof_name[idx]
27 |
28 | ret.update(
29 | {
30 | "mof_name": mof_name,
31 | }
32 | )
33 | ret.update(self.dict_mof[mof_name])
34 | return ret
35 |
36 | @staticmethod
37 | def collate(batch, max_len):
38 | keys = set([key for b in batch for key in b.keys()])
39 | dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys}
40 | # target
41 | dict_batch["target"] = torch.FloatTensor(dict_batch["target"])
42 | # mc (idx)
43 | dict_batch["mc"] = torch.LongTensor(dict_batch["mc"])
44 | # topo (idx)
45 | dict_batch["topo"] = torch.LongTensor(dict_batch["topo"])
46 | # ol (selfies)
47 | ol_len = max_len - 3 # cls, mc, topo
48 | dict_batch["ol"] = torch.LongTensor(
49 | [
50 | ol + [0] * (ol_len - len(ol)) if len(ol) < ol_len else ol[:ol_len]
51 | for ol in dict_batch["ol"]
52 | ]
53 | )
54 | return dict_batch
55 |
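The padding rule inside `collate` is easier to follow on plain lists. This is a minimal sketch of that one expression without the tensor conversion; `pad_or_truncate` is a hypothetical helper name. Three positions are reserved for the cls, mc and topo tokens as in the source, and padding uses index 0.

```python
from typing import List

def pad_or_truncate(ol: List[int], max_len: int, pad_idx: int = 0) -> List[int]:
    """Fit an organic-linker token list into max_len - 3 positions."""
    ol_len = max_len - 3  # cls, mc and topo take the other three slots
    if len(ol) < ol_len:
        return ol + [pad_idx] * (ol_len - len(ol))
    return ol[:ol_len]

batch = [[5, 6, 7], [1, 2, 3, 4, 5, 6, 7]]
padded = [pad_or_truncate(ol, max_len=8) for ol in batch]
print(padded)
```

Every row ends up with the same length, which is what allows the batch to be stacked into a single `LongTensor`.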
--------------------------------------------------------------------------------
/mofreinforce/predictor/gadgets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchmetrics import Metric
3 |
4 |
5 | class Accuracy(Metric):
6 | def __init__(self, dist_sync_on_step=False):
7 | super().__init__(dist_sync_on_step=dist_sync_on_step)
8 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")
9 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
10 |
11 | def update(self, logits, target):
12 | logits, target = (
13 | logits.detach().to(self.correct.device),
14 | target.detach().to(self.correct.device),
15 | )
16 | if len(logits.shape) > 1:
17 | preds = logits.argmax(dim=-1)
18 | else:
19 | # binary accuracy
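20 | # threshold into a new tensor: detach() shares storage with the caller's
21 | # logits, so in-place writes here would modify them too
22 | preds = (logits >= 0.5).to(logits.dtype)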
23 |
24 | preds = preds[target != -100]
25 | target = target[target != -100]
26 |
27 | if target.numel() == 0:
28 | return  # nothing to count; update() is not expected to return a value
29 |
30 | assert preds.shape == target.shape
31 |
32 | self.correct += torch.sum(preds == target)
33 | self.total += target.numel()
34 |
35 | def compute(self):
36 | return self.correct / self.total
37 |
38 |
39 | class Scalar(Metric):
40 | def __init__(self, dist_sync_on_step=False):
41 | super().__init__(dist_sync_on_step=dist_sync_on_step)
42 | self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum")
43 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
44 |
45 | def update(self, scalar):
46 | if isinstance(scalar, torch.Tensor):
47 | scalar = scalar.detach().to(self.scalar.device)
48 | else:
49 | scalar = torch.tensor(scalar).float().to(self.scalar.device)
50 | self.scalar += scalar
51 | self.total += 1
52 |
53 | def compute(self):
54 | return self.scalar / self.total
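The masking step in `Accuracy.update` above (dropping positions whose target is -100) can be illustrated without torch. This is a hedged sketch: the function name `masked_accuracy` is hypothetical, and it returns `None` for an empty selection instead of updating metric state.

```python
from typing import List, Optional

IGNORE_INDEX = -100  # same sentinel that Accuracy.update filters out


def masked_accuracy(preds: List[int], targets: List[int]) -> Optional[float]:
    """Fraction of correct predictions over positions that are not ignored."""
    kept = [(p, t) for p, t in zip(preds, targets) if t != IGNORE_INDEX]
    if not kept:
        # nothing to score, mirroring the early exit in update()
        return None
    return sum(p == t for p, t in kept) / len(kept)


acc = masked_accuracy([1, 0, 2, 2], [1, 1, IGNORE_INDEX, 2])
empty = masked_accuracy([1], [IGNORE_INDEX])
```

Here the third position is ignored, so two of the three remaining predictions count as correct.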
--------------------------------------------------------------------------------
/mofreinforce/predictor/heads.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Pooler(nn.Module):
5 | def __init__(self, hidden_size, index=0):
6 | super().__init__()
7 | self.dense = nn.Linear(hidden_size, hidden_size)
8 | self.activation = nn.Tanh()
9 | self.index = index
10 |
11 | def forward(self, hidden_states):
12 | first_token_tensor = hidden_states[:, self.index]
13 | pooled_output = self.dense(first_token_tensor)
14 | pooled_output = self.activation(pooled_output)
15 | return pooled_output
16 |
17 |
18 | class RegressionHead(nn.Module):
19 | """
20 | head for Regression
21 | """
22 |
23 | def __init__(self, hid_dim):
24 | super().__init__()
25 | self.fc = nn.Linear(hid_dim, 1)
26 |
27 | def forward(self, x):
28 | x = self.fc(x)
29 | return x
30 |
31 |
32 | class ClassificationHead(nn.Module):
33 | """
34 | head for Classification
35 | """
36 |
37 | def __init__(self, hid_dim, n_classes):
38 | super().__init__()
39 |
40 | if n_classes == 2:
41 | self.fc = nn.Linear(hid_dim, 1)
42 | self.binary = True
43 | else:
44 | self.fc = nn.Linear(hid_dim, n_classes)
45 | self.binary = False
46 |
47 | def forward(self, x):
48 | x = self.fc(x)
49 |
50 | return x, self.binary
51 |
--------------------------------------------------------------------------------
/mofreinforce/predictor/module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pytorch_lightning import LightningModule
4 |
5 | from predictor import objectives, heads
6 |
7 | from predictor.transformer import Transformer
8 | from utils import module_utils
9 |
10 | from torchmetrics.functional import r2_score
11 |
12 |
13 | class Predictor(LightningModule):
14 | def __init__(self, config):
15 | super().__init__()
16 | self.save_hyperparameters()
17 | # build transformer
18 | self.transformer = Transformer(
19 | embed_dim=config["hid_dim"],
20 | depth=config["num_layers"],
21 | num_heads=config["num_heads"],
22 | mlp_ratio=config["mlp_ratio"],
23 | drop_rate=config["drop_rate"],
24 | )
25 |
26 | # metal node embedding
27 | self.mc_embedding = nn.Embedding(config["mc_dim"], config["hid_dim"])
28 | self.mc_embedding.apply(module_utils.init_weights)
29 |
30 | # topology embedding
31 | self.topo_embedding = nn.Embedding(config["topo_dim"], config["hid_dim"])
32 | self.topo_embedding.apply(module_utils.init_weights)
33 |
34 | # organic linker embedding
35 | self.ol_embedding = nn.Embedding(
36 | config["vocab_dim"], config["hid_dim"], padding_idx=0
37 | )
38 | self.ol_embedding.apply(module_utils.init_weights)
39 |
40 | # class token
41 | self.cls_embeddings = nn.Linear(1, config["hid_dim"])
42 | self.cls_embeddings.apply(module_utils.init_weights)
43 |
44 | # position embedding
45 | # max_len = ol_max_len (100) + cls + mc + topo
46 | self.pos_embeddings = nn.Parameter(
47 | torch.zeros(1, config["max_len"], config["hid_dim"])
48 | )
49 | # self.pos_embeddings.apply(module_utils.init_weights)
50 |
51 | # pooler
52 | self.pooler = heads.Pooler(config["hid_dim"])
53 | self.pooler.apply(module_utils.init_weights)
54 |
55 | # ===================== loss =====================
56 |
57 | # regression
58 | if self.hparams.config["loss_names"]["regression"] > 0:
59 | self.regression_head = heads.RegressionHead(config["hid_dim"])
60 | self.regression_head.apply(module_utils.init_weights)
61 | # normalization
62 | self.mean = config["mean"]
63 | self.std = config["std"]
64 | self.normalizer = module_utils.Normalizer(self.mean, self.std)
65 |
66 | module_utils.set_metrics(self)
67 | module_utils.set_task(self)
68 | # ===================== load downstream (test_only) ======================
69 |
70 | if config["load_path"] != "" and config["test_only"]:
71 | ckpt = torch.load(config["load_path"], map_location="cpu")
72 | state_dict = ckpt["state_dict"]
73 | self.load_state_dict(state_dict, strict=False)
74 | print(f"load model : {config['load_path']}")
75 |
76 | def infer(self, batch):
77 | mc = batch["mc"] # [B]
78 | topo = batch["topo"] # [B]
79 | ol = batch["ol"] # [B, ol_len]
80 | batch_size = len(mc)
81 |
82 | mc_embeds = self.mc_embedding(mc).unsqueeze(1) # [B, 1, hid_dim]
83 | topo_embeds = self.topo_embedding(topo).unsqueeze(1) # [B, 1, hid_dim]
84 |
85 | ol_embeds = self.ol_embedding(ol) # [B, ol_len, hid_dim]
86 |
87 | cls_tokens = torch.zeros(batch_size).to(ol_embeds) # [B]
88 | cls_embeds = self.cls_embeddings(cls_tokens[:, None, None]) # [B, 1, hid_dim]
89 |
90 | # total_embedding and mask
91 | co_embeds = torch.cat(
92 | [cls_embeds, mc_embeds, topo_embeds, ol_embeds], dim=1
93 | ) # [B, max_len, hid_dim]
94 | co_masks = torch.cat(
95 | [torch.ones([batch_size, 3]).to(ol), (ol != 0).float()], dim=1
96 | )
97 |
98 | # add pos_embeddings
99 | final_embeds = co_embeds + self.pos_embeddings
100 | final_embeds = self.transformer.pos_drop(final_embeds)
101 |
102 | # transformer blocks
103 | x = final_embeds
104 | for i, blk in enumerate(self.transformer.blocks):
105 | x, _attn = blk(x, mask=co_masks)
106 |
107 | x = self.transformer.norm(x)
108 |
109 | cls_feats = self.pooler(x)
110 |
111 | ret = {
112 | "topo_name": batch.get("topo_name", None),
113 | "mc_name": batch.get("mc_name", None),
114 | "ol_name": batch.get("ol_name", None),
115 | "cls_feats": cls_feats,
116 | "mc": mc,
117 | "topo": topo,
118 | "ol": ol,
119 | "output": x,
120 | "output_mask": co_masks,
121 | }
122 | return ret
123 |
124 | def forward(self, batch):
125 | ret = dict()
126 |
127 | if len(self.current_tasks) == 0:
128 | ret.update(self.infer(batch))
129 |
130 | # regression
131 | if "regression" in self.current_tasks:
132 | ret.update(objectives.compute_regression(self, batch, self.normalizer))
133 |
134 | return ret
135 |
136 | def training_step(self, batch, batch_idx):
137 | module_utils.set_task(self)
138 | output = self(batch)
139 | total_loss = sum([v for k, v in output.items() if "loss" in k])
140 |
141 | return total_loss
142 |
143 | def training_epoch_end(self, outputs):
144 | module_utils.epoch_wrapup(self)
145 |
146 | def validation_step(self, batch, batch_idx):
147 | module_utils.set_task(self)
148 | output = self(batch)
149 |
150 | def validation_epoch_end(self, outputs):
151 | module_utils.epoch_wrapup(self)
152 |
153 | def test_step(self, batch, batch_idx):
154 | module_utils.set_task(self)
155 | output = self(batch)
156 | output = {
157 | k: (v.cpu() if torch.is_tensor(v) else v) for k, v in output.items()
158 | } # update cpu for memory
159 | return output
160 |
161 | def test_epoch_end(self, outputs):
162 | module_utils.epoch_wrapup(self)
163 |
164 | # calculate r2 score when regression
165 | if "regression_logits" in outputs[0].keys():
166 | logits = []
167 | labels = []
168 | for out in outputs:
169 | logits += out["regression_logits"].tolist()
170 | labels += out["regression_labels"].tolist()
171 | r2 = r2_score(torch.FloatTensor(logits), torch.FloatTensor(labels))
172 | self.log("test/r2_score", r2)
173 |
174 | """
175 | # example for saving cls_feats
176 | import json
177 | d = {}
178 | for out in outputs:
179 | mc_name = out["mc_name"]
180 | ol_name = out["ol_name"]
181 | topo_name = out["topo_name"]
182 | cls_feats = out["cls_feats"]
183 | for mc, ol, topos, cls_feat in zip(mc_name, ol_name, topo_name, cls_feats):
184 | cif_id = "_".join([str(topos), str(mc), str(ol)])
185 | d.update({cif_id : cls_feat.tolist()})
186 | json.dump(d, open("cls_feats.json", "w"))
187 | """
188 |
189 | def configure_optimizers(self):
190 | return module_utils.set_schedule(self)
191 |
--------------------------------------------------------------------------------
/mofreinforce/predictor/objectives.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torchmetrics.functional import r2_score
4 |
5 |
6 | def weighted_mse_loss(logits, target, weight):
7 | return (weight * (logits - target) ** 2).mean()
8 |
9 |
10 | def compute_regression(pl_module, batch, normalizer):
11 | infer = pl_module.infer(batch)
12 | batch_size = pl_module.hparams.config["batch_size"]
13 |
14 | logits = pl_module.regression_head(infer["cls_feats"]).squeeze(-1) # [B]
15 | target = batch["target"] # [B]
16 |
17 | # normalize the target if config["mean"] and config["std"] are set, else pass it through
18 | target = normalizer.encode(target)
19 |
20 | loss = F.mse_loss(logits, target)
21 |
22 | ret = {
23 | "regression_loss": loss,
24 | "regression_logits": normalizer.decode(logits),
25 | "regression_labels": normalizer.decode(target),
26 | }
27 | ret.update(infer)
28 |
29 | # update and log the loss, MAE, and R2 metrics for the current phase
30 | phase = "train" if pl_module.training else "val"
31 | loss = getattr(pl_module, f"{phase}_regression_loss")(ret["regression_loss"])
32 | mae = getattr(pl_module, f"{phase}_regression_mae")(
33 | F.l1_loss(ret["regression_logits"], ret["regression_labels"])
34 | )
35 |
36 | r2 = getattr(pl_module, f"{phase}_regression_r2")(r2_score(logits, target))
37 |
38 | pl_module.log(
39 | f"regression/{phase}/loss", loss, batch_size=batch_size, prog_bar=True
40 | )
41 | pl_module.log(f"regression/{phase}/mae", mae, batch_size=batch_size, prog_bar=True)
42 | pl_module.log(f"regression/{phase}/r2", r2, batch_size=batch_size, prog_bar=True)
43 | return ret
44 |
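`compute_regression` above relies on `normalizer.encode` and `normalizer.decode`, whose implementation (`module_utils.Normalizer`) is not shown in this listing. The sketch below is an assumption about its behavior, inferred only from the comment "normalize ... else pass": standard-score scaling when both `mean` and `std` are given, identity otherwise. The class name `SketchNormalizer` is hypothetical.

```python
from typing import Optional


class SketchNormalizer:
    """Hypothetical stand-in for module_utils.Normalizer (assumed behavior)."""

    def __init__(self, mean: Optional[float], std: Optional[float]) -> None:
        self.mean = mean
        self.std = std

    def _active(self) -> bool:
        # scaling is applied only when both statistics are provided
        return self.mean is not None and self.std is not None

    def encode(self, x: float) -> float:
        # raw target -> normalized target
        return (x - self.mean) / self.std if self._active() else x

    def decode(self, x: float) -> float:
        # normalized prediction -> raw units
        return x * self.std + self.mean if self._active() else x


norm = SketchNormalizer(mean=2.0, std=4.0)
encoded = norm.encode(10.0)
decoded = norm.decode(encoded)
passthrough = SketchNormalizer(None, None).encode(3.0)
```

Under this assumption the loss is computed in normalized units, while the logged MAE uses decoded values in the original units.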
--------------------------------------------------------------------------------
/mofreinforce/predictor/transformer.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | import torch.nn as nn
5 | from timm.models.layers import DropPath, trunc_normal_
6 |
7 |
8 | class Mlp(nn.Module):
9 | def __init__(
10 | self,
11 | in_features,
12 | hidden_features=None,
13 | out_features=None,
14 | act_layer=nn.GELU,
15 | drop=0.0,
16 | ):
17 | super().__init__()
18 | out_features = out_features or in_features
19 | hidden_features = hidden_features or in_features
20 | self.fc1 = nn.Linear(in_features, hidden_features)
21 | self.act = act_layer()
22 | self.fc2 = nn.Linear(hidden_features, out_features)
23 | self.drop = nn.Dropout(drop)
24 |
25 | def forward(self, x):
26 | x = self.fc1(x)
27 | x = self.act(x)
28 | x = self.drop(x)
29 | x = self.fc2(x)
30 | x = self.drop(x)
31 | return x
32 |
33 |
34 | class Attention(nn.Module):
35 | def __init__(
36 | self,
37 | dim,
38 | num_heads=8,
39 | qkv_bias=False,
40 | qk_scale=None,
41 | attn_drop=0.0,
42 | proj_drop=0.0,
43 | ):
44 | super().__init__()
45 | self.num_heads = num_heads
46 | head_dim = dim // num_heads
47 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
48 | self.scale = qk_scale or head_dim**-0.5
49 |
50 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
51 | self.attn_drop = nn.Dropout(attn_drop)
52 | self.proj = nn.Linear(dim, dim)
53 | self.proj_drop = nn.Dropout(proj_drop)
54 |
55 | def forward(self, x, mask=None):
56 | B, N, C = x.shape
57 | assert C % self.num_heads == 0
58 | qkv = (
59 | self.qkv(x) # [B, N, 3*C]
60 | .reshape(
61 | B, N, 3, self.num_heads, C // self.num_heads
62 | ) # [B, N, 3, num_heads, C//num_heads]
63 | .permute(2, 0, 3, 1, 4) # [3, B, num_heads, N, C//num_heads]
64 | )
65 | q, k, v = (
66 | qkv[0], # [B, num_heads, N, C//num_heads]
67 | qkv[1], # [B, num_heads, N, C//num_heads]
68 | qkv[2], # [B, num_heads, N, C//num_heads]
69 | ) # make torchscript happy (cannot use tensor as tuple)
70 |
71 | attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, N, N]
72 | if mask is not None:
73 | mask = mask.bool()
74 | attn = attn.masked_fill(~mask[:, None, None, :], float("-inf"))
75 | attn = attn.softmax(dim=-1) # [B, num_heads, N, N]
76 | attn = self.attn_drop(attn)
77 |
78 | x = (
79 | (attn @ v).transpose(1, 2).reshape(B, N, C)
80 | ) # [B, num_heads, N, C//num_heads] -> [B, N, C]
81 | x = self.proj(x)
82 | x = self.proj_drop(x)
83 | return x, attn
84 |
85 |
86 | class Block(nn.Module):
87 | def __init__(
88 | self,
89 | dim,
90 | num_heads,
91 | mlp_ratio=4.0,
92 | qkv_bias=False,
93 | qk_scale=None,
94 | drop=0.0,
95 | attn_drop=0.0,
96 | drop_path=0.0,
97 | act_layer=nn.GELU,
98 | norm_layer=nn.LayerNorm,
99 | ):
100 | super().__init__()
101 | self.norm1 = norm_layer(dim)
102 | self.attn = Attention(
103 | dim,
104 | num_heads=num_heads,
105 | qkv_bias=qkv_bias,
106 | qk_scale=qk_scale,
107 | attn_drop=attn_drop,
108 | proj_drop=drop,
109 | )
110 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
111 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
112 | self.norm2 = norm_layer(dim)
113 | mlp_hidden_dim = int(dim * mlp_ratio)
114 | self.mlp = Mlp(
115 | in_features=dim,
116 | hidden_features=mlp_hidden_dim,
117 | act_layer=act_layer,
118 | drop=drop,
119 | )
120 |
121 | def forward(self, x, mask=None):
122 | _x, attn = self.attn(self.norm1(x), mask=mask)
123 | x = x + self.drop_path(_x)
124 | x = x + self.drop_path(self.mlp(self.norm2(x)))
125 | return x, attn
126 |
127 |
128 | class Transformer(nn.Module):
129 | def __init__(
130 | self,
131 | embed_dim,
132 | depth=12,
133 | num_heads=12,
134 | mlp_ratio=4.0,
135 | qkv_bias=True,
136 | qk_scale=None,
137 | drop_rate=0.0,
138 | attn_drop_rate=0.0,
139 | drop_path_rate=0.0,
140 | norm_layer=None,
141 | add_norm_before_transformer=False,
142 | config=None,
143 | ):
144 | super().__init__()
145 |
146 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
147 | self.add_norm_before_transformer = add_norm_before_transformer
148 | if add_norm_before_transformer:
149 | self.pre_norm = norm_layer(embed_dim)
150 |
151 | dpr = [
152 | x.item() for x in torch.linspace(0, drop_path_rate, depth)
153 | ] # stochastic depth decay rule
154 | self.blocks = nn.ModuleList(
155 | [
156 | Block(
157 | dim=embed_dim,
158 | num_heads=num_heads,
159 | mlp_ratio=mlp_ratio,
160 | qkv_bias=qkv_bias,
161 | qk_scale=qk_scale,
162 | drop=drop_rate,
163 | attn_drop=attn_drop_rate,
164 | drop_path=dpr[i],
165 | norm_layer=norm_layer,
166 | )
167 | for i in range(depth)
168 | ]
169 | )
170 | self.norm = norm_layer(embed_dim)
171 |
172 | # trunc_normal_(self.cls_token, std=0.02)
173 | self.pos_drop = nn.Dropout(p=drop_rate)
174 |
175 | self.apply(self._init_weights)
176 |
177 | def _init_weights(self, m):
178 | if isinstance(m, nn.Linear):
179 | trunc_normal_(m.weight, std=0.02)
180 | if isinstance(m, nn.Linear) and m.bias is not None:
181 | nn.init.constant_(m.bias, 0)
182 | elif isinstance(m, nn.LayerNorm):
183 | nn.init.constant_(m.bias, 0)
184 | nn.init.constant_(m.weight, 1.0)
185 |
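The key masking in `Attention.forward` above (filling padded key positions with `-inf` before the softmax) can be shown for a single query row in plain Python. This is a minimal sketch; `masked_softmax` is a hypothetical helper, not part of the repository.

```python
import math
from typing import List


def masked_softmax(scores: List[float], mask: List[bool]) -> List[float]:
    """Softmax over scores where positions with mask == False get zero weight."""
    masked = [s if keep else float("-inf") for s, keep in zip(scores, mask)]
    top = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


weights = masked_softmax([1.0, 2.0, 3.0], [True, True, False])
```

The sketch assumes at least one position is unmasked. In `Predictor.infer` the cls, mc, and topo positions are always unmasked, so a fully masked row does not arise there.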
--------------------------------------------------------------------------------
/mofreinforce/reinforce/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hspark1212/MOFreinforce/b920084a29db9223482ddfeffd2e775b82f49e63/mofreinforce/reinforce/__init__.py
--------------------------------------------------------------------------------
/mofreinforce/reinforce/config_reinforce.py:
--------------------------------------------------------------------------------
1 | from sacred import Experiment
2 |
3 | ex = Experiment("reinforce")
4 |
5 |
6 | @ex.config
7 | def config():
8 | seed = 0
9 | exp_name = ""
10 | gpu_idx = 0
11 |
12 | # load predictor
13 | predictor_load_path = ["model/predictor_qkh.ckpt"]
14 | mean = [None]
15 | std = [None]
16 |
17 | # get reward
18 | threshold = False
19 | """
20 | (1) threshold == True
21 | reward = 1 if |pred| >= |reward_max| for every predictor, else reward = 0
22 | (2) threshold == False
23 | reward = pred / reward_max
24 | """
25 | reward_max = [1.0] # if you want to normalize by max reward.
26 |
27 | # load generator
28 | dataset_dir = "data/dataset_generator"
29 | generator_load_path = "model/generator/generator.ckpt"
30 |
31 | # REINFORCE
32 | lr = 1e-4
33 | decay_ol = 1.0
34 | decay_topo = 1.0
35 | scheduler = "constant" # warmup = get_linear_schedule_with_warmup, constant = get_constant_schedule
36 | early_stop = 0.5 # stop training when the scaffold accuracy falls below this value
37 | ratio_exploit = 0.0 # probability of taking a decoding step from the frozen agent (exploitation)
38 | """ decay_topo decreases the effect of topology when calculating the loss.
39 | Because topology is a discrete feature, its loss could be much larger than the sum over the organic linker,
40 | which can cause training to fail.
41 | """
42 | ratio_mask_mc = 0.0 # ratio for masking mc of input_src
43 |
44 | # Trainer
45 | max_epochs = 20
46 | batch_size = 16
47 | accumulate_grad_batches = 2
48 | devices = 1
49 | num_nodes = 1
50 | precision = 16
51 | resume_from = None
52 | limit_train_batches = 500
53 | limit_val_batches = 30
54 | val_check_interval = 1.0
55 | test_only = False
56 | load_path = ""
57 | gradient_clip_val = None
58 |
59 | # tensorboard
60 | log_dir = "reinforce/logs"
61 |
62 |
63 | @ex.named_config
64 | def test():
65 | exp_name = "test"
66 |
67 |
68 | """
69 | v0 Q_kH
70 | """
71 |
72 |
73 | @ex.named_config
74 | def v0_scratch():
75 | exp_name = "v0_scratch"
76 | test_only = True
77 |
78 | # reward
79 | reward_max = [-60.0]
80 |
81 | # predictor
82 | predictor_load_path = ["model/predictor/best_predictor_v0_qkh_round1.ckpt"]
83 | mean = [-20.331]
84 | std = [-10.383]
85 |
86 |
87 | @ex.named_config
88 | def v0_qkh_round1():
89 | """
90 | omit mc in the input
91 | """
92 | exp_name = "v0_qkh_round1"
93 | max_epochs = 20
94 |
95 | # reward
96 | reward_max = [-60.0]
97 |
98 | # reinforce
99 | early_stop = 0.5 # stop training when the scaffold accuracy falls below this value
100 | ratio_exploit = 0.5 # ratio for exploitation
101 | ratio_mask_mc = 0.5 # ratio for masking mc of input_src
102 |
103 | # predictor
104 | predictor_load_path = ["model/predictor/best_predictor_v0_qkh_round1.ckpt"]
105 | mean = [-20.331]
106 | std = [-10.383]
107 |
108 |
109 | @ex.named_config
110 | def v0_qkh_round2():
111 | """
112 | omit mc in the input
113 | """
114 | exp_name = "v0_qkh_round2"
115 | max_epochs = 20
116 |
117 | # reward
118 | reward_max = [-60.0]
119 |
120 | # reinforce
121 | early_stop = 0.5 # stop training when the scaffold accuracy falls below this value
122 | ratio_exploit = 0.5 # ratio for exploitation
123 | ratio_mask_mc = 0.5 # ratio for masking mc of input_src
124 |
125 | # predictor
126 | predictor_load_path = ["model/predictor/best_predictor_v0_qkh_round2.ckpt"]
127 | mean = [-21.068]
128 | std = [10.950]
129 |
130 |
131 | @ex.named_config
132 | def v0_qkh_round3():
133 | """
134 | omit mc in the input
135 | """
136 | exp_name = "v0_qkh_round3"
137 | max_epochs = 20
138 |
139 | # reward
140 | reward_max = [-60.0]
141 |
142 | # reinforce
143 | early_stop = 0.5 # stop training when the scaffold accuracy falls below this value
144 | ratio_exploit = 0.5 # ratio for exploitation
145 | ratio_mask_mc = 0.5 # ratio for masking mc of input_src
146 |
147 | # predictor
148 | predictor_load_path = ["model/predictor/best_predictor_v0_qkh_round3.ckpt"]
149 | mean = [-21.810]
150 | std = [11.452]
151 |
152 |
153 | """
154 | Selectivity (v1)
155 | """
156 |
157 |
158 | @ex.named_config
159 | def v1_scratch():
160 | exp_name = "v1_scratch"
161 | test_only = True
162 |
163 | # reward
164 | reward_max = [10.0]
165 |
166 | # predictor
167 | predictor_load_path = ["model/predictor/best_predictor_v1_selectivity_round1.ckpt"]
168 | mean = [1.871]
169 | std = [1.922]
170 |
171 |
172 | @ex.named_config
173 | def v1_selectivity_round1():
174 | exp_name = "v1_selectivity_round1"
175 | max_epochs = 20
176 |
177 | # reward
178 | reward_max = [10.0]
179 |
180 | # reinforce
181 | early_stop = 0.5 # stop training when the scaffold accuracy falls below this value
182 | ratio_exploit = 0.5 # ratio for exploitation
183 | ratio_mask_mc = 0.5 # ratio for masking mc of input_src
184 |
185 | # predictor
186 | predictor_load_path = ["model/predictor/best_predictor_v1_selectivity_round1.ckpt"]
187 | mean = [1.871]
188 | std = [1.922]
189 |
190 |
191 | @ex.named_config
192 | def v1_selectivity_round2():
193 | exp_name = "v1_selectivity_round2"
194 | max_epochs = 20
195 |
196 | # reward
197 | reward_max = [10.0]
198 |
199 | # reinforce
200 | early_stop = 0.5 # stop training when the scaffold accuracy falls below this value
201 | ratio_exploit = 0.5 # ratio for exploitation
202 | ratio_mask_mc = 0.5 # ratio for masking mc of input_src
203 |
204 | # predictor
205 | predictor_load_path = ["model/predictor/best_predictor_v1_selectivity_round2.ckpt"]
206 | mean = [2.085]
207 | std = [2.052]
208 |
209 |
210 | @ex.named_config
211 | def v1_selectivity_round3():
212 | exp_name = "v1_selectivity_round3"
213 | max_epochs = 20
214 |
215 | # reward
216 | reward_max = [10.0]
217 |
218 | # reinforce
219 | early_stop = 0.5 # stop training when the scaffold accuracy falls below this value
220 | ratio_exploit = 0.5 # ratio for exploitation
221 | ratio_mask_mc = 0.5 # ratio for masking mc of input_src
222 |
223 | # predictor
224 | predictor_load_path = ["model/predictor/best_predictor_v1_selectivity_round3.ckpt"]
225 | mean = [2.258]
226 | std = [2.128]
227 |
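The config values `threshold`, `reward_max`, `decay_ol`, and `decay_topo` above enter the training signal roughly as follows. This is a plain-Python sketch that follows the logic of `Reinforce.get_reward` and `Reinforce.training_step` in `reinforce/module.py`; the function names `compute_rewards` and `reinforce_loss` are hypothetical, and floats stand in for tensors.

```python
import math
from typing import List, Sequence


def compute_rewards(preds: Sequence[float], reward_max: Sequence[float],
                    threshold: bool) -> List[float]:
    """One reward per predictor, following the two modes of get_reward."""
    if not threshold:
        # continuous mode: prediction scaled by the configured maximum
        return [p / r for p, r in zip(preds, reward_max)]
    # threshold mode: all-or-nothing across predictors
    passed = all(abs(p) >= abs(r) for p, r in zip(preds, reward_max))
    return [1.0 if passed else 0.0] * len(preds)


def reinforce_loss(probs: Sequence[float], reward: float,
                   decay_ol: float = 1.0, decay_topo: float = 1.0) -> float:
    """Negative log-likelihood weighted by the (discounted) reward."""
    loss = 0.0
    # first two sampled tokens: topology and metal cluster
    for p in probs[:2]:
        loss -= math.log(p) * reward * decay_topo
    # linker tokens, walked from last to first so the final token
    # receives the full reward and earlier tokens a decayed one
    discounted = reward
    for p in reversed(probs[2:]):
        loss -= math.log(p) * discounted
        discounted *= decay_ol
    return loss


rewards = compute_rewards([-30.0], [-60.0], threshold=False)
loss = reinforce_loss([0.5, 0.5, 0.5, 0.5], reward=sum(rewards), decay_ol=0.5)
```

With `decay_ol = 1.0` every linker token is weighted by the same reward; values below 1.0 shift the weight toward the end of the generated linker.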
--------------------------------------------------------------------------------
/mofreinforce/reinforce/module.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import copy
4 | import torch
5 | from tqdm import tqdm
6 |
7 | from utils.metrics import Metrics
8 |
9 | from pytorch_lightning import LightningModule
10 | from transformers import get_linear_schedule_with_warmup, get_constant_schedule
11 |
12 | from rdkit import Chem
13 | from rdkit.Chem.Draw import MolToImage
14 |
15 | import numpy as np
16 | import libs.selfies as sf
17 |
18 | topo_to_cn = json.load(open("data/final_topo_cn.json"))
19 | mc_to_cn = json.load(open("data/mc_cn.json"))
20 |
21 |
22 | class Reinforce(LightningModule):
23 | def __init__(self, agent, predictors, config):
24 | super(Reinforce, self).__init__()
25 |
26 | self.agent = copy.deepcopy(agent)
27 | # freeze only encoder
28 | for param in self.agent.transformer.encoder.parameters():
29 | param.requires_grad = False
30 |
31 | self.freeze_agent = copy.deepcopy(agent)
32 | # freeze agent
33 | for param in self.freeze_agent.parameters():
34 | param.requires_grad = False
35 |
36 | self.predictors = predictors
37 |
38 | # reward
39 | self.threshold = config["threshold"]
40 | self.reward_max = config["reward_max"]
41 |
42 | # reinforce
43 | self.lr = config["lr"]
44 | self.decay_ol = config["decay_ol"]
45 | self.scheduler = config["scheduler"]
46 | self.decay_topo = config["decay_topo"]
47 | self.early_stop = config["early_stop"]
48 | self.ratio_exploit = config["ratio_exploit"]
49 | self.ratio_mask_mc = config["ratio_mask_mc"]
50 |
51 | # load model
52 | if config["load_path"] != "":
53 | ckpt = torch.load(config["load_path"], map_location="cpu")
54 | state_dict = ckpt["state_dict"]
55 | self.load_state_dict(state_dict, strict=False)
56 | print(f"load model : {config['load_path']}")
57 |
58 | def forward(self, src, max_len=128):
59 | """
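60 | Sample one MOF from the agent; based on evaluate() in Generator.
61 | :param src: encoded_input of generator [B=1, seq_len]
62 | :return: dict with topo, mc, topo_idx, mc_idx, ol_idx, gen_sf, gen_sm (None if invalid), and list_prob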
63 | """
64 |
65 | vocab_to_idx = self.agent.vocab_to_idx
66 |
67 | # mask_mc of encoded_src
68 | if torch.rand(1).item() < self.ratio_mask_mc: # uniform draw, so mc is masked with probability ratio_mask_mc
69 | src[:, 0].fill_(0) # 0 is pad_idx
70 |
71 | src_mask = self.agent.transformer.make_src_mask(src)
72 |
73 | enc_src = self.agent.transformer.encoder(src, src_mask) # [B=1, seq_len, hid_dim]
74 |
75 | # get target
76 | trg_indexes = [vocab_to_idx["[SOS]"]]
77 |
78 | list_prob = []
79 |
80 | for i in range(max_len):
81 |
82 | trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(src.device)
83 | trg_mask = self.agent.transformer.make_trg_mask(trg_tensor) # [B=1, 1, seq_len, seq_len]
84 |
85 | output = self.agent.transformer.decoder(
86 | trg_tensor, enc_src, trg_mask, src_mask) # [B=1, seq_len, vocab_dim]
87 |
88 | with torch.no_grad():
89 | freeze_output = self.freeze_agent.transformer.decoder(
90 | trg_tensor, enc_src, trg_mask, src_mask) # [B=1, seq_len, vocab_dim]
91 | if torch.rand(1)[0] < self.ratio_exploit:
92 | output = freeze_output
93 |
94 | if "output_ol" in output.keys():
95 | out = output["output_ol"]
96 | prob = out.softmax(dim=-1)[0, -1]
97 | m = torch.multinomial(prob, 1).item()
98 | elif "output_mc" in output.keys():
99 | out = output["output_mc"]
100 | prob = out.softmax(dim=-1)[0]
101 | m = torch.multinomial(prob, 1).item()
102 | else:
103 | out = output["output_topo"]
104 | prob = out.softmax(dim=-1)[0]
105 | m = torch.multinomial(prob, 1).item()
106 |
107 | list_prob.append(prob[m])
108 | trg_indexes.append(m)
109 |
110 | if i >= 2 and m == vocab_to_idx["[EOS]"]:
111 | break
112 |
113 | # get topo
114 | topo_idx = trg_indexes[1]
115 | topo = self.agent.idx_to_topo[topo_idx]
116 | # get mc
117 | mc_idx = trg_indexes[2]
118 | mc = self.agent.idx_to_mc[mc_idx]
119 | # get ol
120 | ol_idx = trg_indexes[3:]
121 | ol_tokens = [self.agent.idx_to_vocab[idx] for idx in ol_idx]
122 |
123 | # check matching connection points
124 | topo_cn = list(topo_to_cn.get(topo, [0])) # copy so the append below does not mutate the module-level table; [PAD] gives [0]
125 | if len(topo_cn) == 1:
126 | topo_cn.append(2)
127 | mc_cn = mc_to_cn.get(mc, -1) # if mc is [PAD] then mc_cn = -1
128 | ol_cn = ol_idx.count(self.agent.vocab_to_idx["[*]"])
129 |
130 | # convert to selfies and smiles
131 | try:
132 | gen_sf = "".join(ol_tokens[:-1]) # remove EOS token
133 | gen_sm = sf.decoder(gen_sf)
134 | new_gen_sf = sf.encoder(gen_sm) # to check valid SELFIES
135 | assert ol_tokens[-1] == "[EOS]", Exception("The last token is not [EOS]")
136 | assert new_gen_sf == gen_sf, Exception("SELFIES error")
137 | assert set(topo_cn) == set([mc_cn, ol_cn]), Exception("connection points error")
138 | except Exception as e:
139 | print(e)
140 | # print(f"The failed gen_sf : {gen_sf}")
141 | # print(f"The failed gen_sm : {gen_sm}")
142 | gen_sf = None
143 | gen_sm = None
144 |
145 | ret = {
146 | "topo": topo,
147 | "mc": mc,
148 | "topo_idx": topo_idx,
149 | "mc_idx": mc_idx,
150 | "ol_idx": ol_idx,
151 | "gen_sf": gen_sf,
152 | "gen_sm": gen_sm,
153 | "list_prob": list_prob,
154 | }
155 | return ret
156 |
157 | def configure_optimizers(self):
158 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
159 | max_steps = self.trainer.estimated_stepping_batches
160 | print(f"max_steps : {max_steps}")
161 |
162 | if self.scheduler == "warmup":
163 | scheduler = get_linear_schedule_with_warmup(
164 | optimizer=optimizer,
165 | num_warmup_steps=int(self.trainer.max_epochs * 0.1),
166 | num_training_steps=self.trainer.max_epochs,
167 | )
168 | elif self.scheduler == "constant":
169 | scheduler = get_constant_schedule(
170 | optimizer=optimizer,
171 | )
172 | return [optimizer], [scheduler]
173 |
174 | def training_step(self, batch, batch_idx):
175 | self.agent.train()
176 |
177 | rl_loss = 0
178 | total_reward = 0
179 | num_fail = 0
180 |
181 | for enc_input in batch["encoded_input"]:
182 | src = enc_input.unsqueeze(0)
183 | out = self(src)
184 | if out["gen_sm"] is None:
185 | num_fail += 1
186 | continue
187 |
188 | # get reward from predictor
189 | with torch.no_grad():
190 | rewards, preds = self.get_reward(output_generator=out)
191 | reward = sum(rewards)
192 |
193 | total_reward += reward
194 |
195 | # Reinforce algorithm
196 | discounted_reward = reward
197 | tmp_loss = []
198 |
199 | # topology, metal cluster (the loss for the metal cluster is almost zero, because it is given as input)
200 | for prob in out["list_prob"][:2]:
201 | rl_loss -= torch.log(prob) * reward * self.decay_topo
202 | tmp_loss.append(-torch.log(prob).item() * reward)
203 | # organic linker (discounted_reward is applied only for organic linkers)
204 | for prob in out["list_prob"][2:][::-1]:
205 | rl_loss -= torch.log(prob) * discounted_reward
206 | tmp_loss.append(-torch.log(prob).item() * discounted_reward) # log the same weight used in rl_loss
207 | discounted_reward = discounted_reward * self.decay_ol
208 |
209 | print(rewards, preds, out["topo"], out["mc"], out["gen_sm"], out["gen_sf"], src[0, 1].item())
210 | print(rl_loss, tmp_loss[:2], sum(tmp_loss[2:]))
211 |
212 | n_batch = batch["encoded_input"].shape[0]
213 | rl_loss = rl_loss / n_batch
214 | total_reward = total_reward / n_batch
215 | num_fail = torch.tensor(num_fail) / n_batch
216 |
217 | self.log("train/rl_loss", rl_loss)
218 | self.log("train/reward", total_reward, prog_bar=True)
219 | self.log("train/num_fail", num_fail, prog_bar=True)
220 |
221 | return rl_loss
222 |
223 | def validation_step(self, batch, batch_idx):
224 | return batch
225 |
226 | def validation_epoch_end(self, batches):
227 | list_src = torch.concat([b["encoded_input"] for b in batches], dim=0)
228 | metrics = Metrics(self.agent.vocab_to_idx, self.agent.idx_to_vocab)
229 | split = "val"
230 |
231 | metrics = self.update_metrics(list_src, metrics, split)
232 | if metrics.get_mean(metrics.scaffold) < self.early_stop:
233 | raise Exception(f"EarlyStopping for scaffold : scaffold accuracy is less than {self.early_stop}")
234 |
235 | def test_step(self, batch, batch_idx):
236 | return batch
237 |
238 | def test_epoch_end(self, batches):
239 | list_src = torch.concat([b["encoded_input"] for b in batches], dim=0)
240 | metrics = Metrics(self.agent.vocab_to_idx, self.agent.idx_to_vocab)
241 | split = "test"
242 |
243 | metrics = self.update_metrics(list_src, metrics, split)
244 | # save results
245 | ret = {
246 | "rewards": metrics.rewards,
247 | "preds": metrics.preds,
248 | "gen_sms": metrics.gen_ol,
249 | "gen_mcs": metrics.gen_mc,
250 | "gen_topos": metrics.gen_topo,
251 | }
252 | path = os.path.join(self.logger.save_dir, f"results_{self.logger.name}.json")
253 | json.dump(ret, open(path, "w"))
254 |
255 | def get_reward(self,
256 | output_generator,
257 | sos_idx=1,
258 | pad_idx=0,
259 | max_len=128,
260 | ):
261 | # create the input of predictor
262 | device = self.predictors[0].device
263 | batch = {}
264 | batch["mc"] = torch.LongTensor([output_generator["mc_idx"]]).to(device)
265 | batch["topo"] = torch.LongTensor([output_generator["topo_idx"]]).to(device)
266 | batch["ol"] = torch.LongTensor([[sos_idx] + # add sos_idx
267 | output_generator["ol_idx"][:max_len - 4] +
268 | [pad_idx] * (max_len - 4 - len(output_generator["ol_idx"]))]).to(
269 | device) # add pad_idx
270 |
271 | preds = []
272 | rewards = []
273 | for i, predictor in enumerate(self.predictors):
274 | infer = predictor.infer(batch)
275 | p = predictor.regression_head(infer["cls_feats"])
276 | p = predictor.normalizer.decode(p)
277 | p = p.detach().item()
278 | preds.append(p)
279 |
280 | # get reward
281 | if self.threshold:
282 | if abs(p) >= abs(self.reward_max[i]):
283 | r = 1
284 | else:
285 | r = 0
286 | rewards.append(r)
287 | else:
288 | rewards.append(p / self.reward_max[i])
289 |
290 | if self.threshold:
291 | if all(rewards):
292 | rewards = [1] * len(rewards)
293 | else:
294 | rewards = [0] * len(rewards)
295 | return rewards, preds
296 |
297 | def update_metrics(self, list_src, metrics, split, num_log_mols=16):
298 | """
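299 | Generate a MOF for each source, compute its rewards, and log the metrics.
300 | :param list_src: tensor of encoded source sequences
301 | :param metrics: Metrics object that accumulates the results
302 | :param split: str, name of the split (e.g. "test")
303 | :param num_log_mols: number of molecules whose images are saved in tensorboard
304 | :return: the updated metrics
305 | :rtype: Metrics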
306 | """
307 | for src in tqdm(list_src):
308 | out = self(src.unsqueeze(0))
309 |
310 | if out["gen_sm"] is None:
311 | metrics.num_fail.append(1)
312 | continue
313 | else:
314 | metrics.num_fail.append(0)
315 |
316 | with torch.no_grad():
317 | rewards, preds = self.get_reward(output_generator=out)
318 |
319 | metrics.update(out, src, rewards=rewards, preds=preds)
320 |
321 | self.log(f"{split}/conn_match", metrics.get_mean(metrics.conn_match))
322 | self.log(f"{split}/unique_ol", len(set(metrics.gen_ol)) / len(metrics.gen_ol))
323 | self.log(f"{split}/unique_topo_mc", len(set(zip(metrics.gen_topo, metrics.gen_mc))) / len(metrics.gen_topo))
324 | self.log(f"{split}/scaffold", metrics.get_mean(metrics.scaffold))
325 |
326 | self.log(f"{split}/total_reward", metrics.get_mean(metrics.rewards))
327 | for i, reward in enumerate(zip(*metrics.rewards)):
328 | self.log(f"{split}/reward_{i}", metrics.get_mean(reward))
329 | for i, pred in enumerate(zip(*metrics.preds)):
330 | self.log(f"{split}/target_{i}", metrics.get_mean(pred))
331 | self.log(f"{split}/num_fail", metrics.get_mean(metrics.num_fail))
332 |
333 | # add image to log
334 | for i in range(min(num_log_mols, len(metrics.gen_ol))):
335 | ol = metrics.gen_ol[i]
336 | frags = metrics.input_frags[i]
337 | imgs = []
338 | for s in [ol] + frags:
339 | m = Chem.MolFromSmiles(s)
340 | if not m:
341 | continue
342 | img = MolToImage(m)
343 | img = np.array(img)
344 | img = torch.tensor(img)
345 | imgs.append(img)
346 | imgs = np.stack(imgs, axis=0)
347 | self.logger.experiment.add_image(f"{split}/{i}", imgs, self.global_step, dataformats="NHWC")
348 |
349 | # total gen_ol
350 | imgs = []
351 | for i, m in enumerate(metrics.gen_ol[:num_log_mols]):
352 | try:
353 | m = Chem.MolFromSmiles(m)
354 | img = MolToImage(m)
355 | img = np.array(img)
356 | img = torch.tensor(img)
357 | imgs.append(img)
358 | except Exception as e:
359 | print(e)
360 | imgs = np.stack(imgs, axis=0)
361 | self.logger.experiment.add_image(f"{split}/gen_ol/", imgs, self.global_step, dataformats="NHWC")
362 |
363 | return metrics
364 |
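The reward branching in `get_reward` above can be illustrated without the predictors. This is a minimal sketch, not the module's code: `compute_rewards` is a hypothetical helper that takes already-decoded predictions as plain floats, and the numbers in the usage note are made up.

```python
def compute_rewards(preds, reward_max, threshold):
    """Turn per-target predictions into rewards.

    Hypothetical stand-alone helper that mirrors the branching in get_reward.
    """
    if not threshold:
        # continuous reward: each prediction scaled by its target value
        return [p / r_max for p, r_max in zip(preds, reward_max)]

    # binary reward per target: 1 if |prediction| reaches |target|
    hits = [1 if abs(p) >= abs(r_max) else 0 for p, r_max in zip(preds, reward_max)]
    # all-or-nothing: every target must be reached to earn any reward
    value = 1 if all(hits) else 0
    return [value] * len(hits)
```

For example, `compute_rewards([5.0, 1.0], [4.0, 2.0], True)` gives no reward at all, because the second target is missed even though the first one is reached.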
--------------------------------------------------------------------------------
/mofreinforce/reinforce/reinforce.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from functools import partial
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.distributions import Categorical
9 | from torch.utils.tensorboard.writer import SummaryWriter
10 |
11 | from rdkit import Chem
12 | from rdkit.Chem.Draw import MolToImage
13 |
14 | import matplotlib.pyplot as plt
15 | import seaborn as sns
16 | import numpy as np
17 | import libs.selfies as sf
18 |
19 | class Reinforce(object):
20 | def __init__(self,
21 | generator,
22 | rmsd_predictor,
23 | target_predictor,
24 | get_reward_rmsd,
25 | get_reward_target,
26 | vocab_to_idx,
27 | mc_to_idx,
28 | topo_to_idx,
29 | emb_dim=128,
30 | hid_dim=128,
31 | config=None,
32 | ):
33 | """
34 | REINFORCE algorithm to generate MOFs that maximize the reward functions (rmsd and target)
35 | 
36 | Parameters:
37 | generator (nn.Module): generator
38 | rmsd_predictor (nn.Module): predictor for rmsd
39 | target_predictor (nn.Module): predictor for target
40 | get_reward_rmsd (function): get rewards for rmsd
41 | get_reward_target (function): get rewards for target
42 | vocab_to_idx (dictionary): dictionary mapping SELFIES tokens to indices
43 | mc_to_idx (dictionary): dictionary mapping metal clusters to indices
44 | topo_to_idx (dictionary): dictionary mapping topologies to indices
45 | emb_dim (int): embedding dimension for metal cluster and topology
46 | hid_dim (int): hidden dimension for metal cluster and topology
47 | """
48 | super(Reinforce, self).__init__()
49 | self.generator = generator
50 | self.rmsd_predictor = rmsd_predictor
51 | self.target_predictor = target_predictor
52 | self.get_reward_rmsd = partial(get_reward_rmsd,
53 | criterion=config["criterion_rmsd"],
54 | reward_positive=config["reward_positive_rmsd"],
55 | reward_negative=config["reward_negative_rmsd"],
56 | )
57 | self.get_reward_target = partial(get_reward_target,
58 | criterion=config["criterion_target"],
59 | reward_positive=config["reward_positive_target"],
60 | reward_negative=config["reward_negative_target"],
61 | )
62 | 
63 | self.vocab_to_idx = vocab_to_idx
64 | self.mc_to_idx = mc_to_idx
65 | self.topo_to_idx = topo_to_idx
66 | self.emb_dim = emb_dim
67 | self.hid_dim = hid_dim
68 | 
69 | self.rmsd_predictor.eval()
70 | self.target_predictor.eval()
71 | self.upgrade_generator()
72 | 
73 | self.log_dir = f"{config['log_dir']}/{config['exp_name']}"
74 | if os.path.exists(self.log_dir) and config["test_only"] is not True:
75 | raise Exception(f"{self.log_dir} already exists")
76 | 
77 | self.writer = SummaryWriter(log_dir=self.log_dir)
78 | 
79 | # save config next to the TensorBoard logs (this class has no logger; the writer creates log_dir)
80 | with open(os.path.join(self.log_dir, "hparams.json"), "w") as f:
81 | json.dump(config, f)
82 |
83 | def upgrade_generator(self):
84 | """
85 | add metal cluster and topology to the generator
86 | """
87 | self.generator.embedding_topo = nn.Embedding(len(self.topo_to_idx), self.emb_dim)
88 | self.generator.rc_mc_1 = nn.Linear(self.emb_dim, self.hid_dim)
89 | self.generator.rc_mc_2 = nn.Linear(self.hid_dim, len(self.mc_to_idx))
90 |
91 | self.generator.optimizer = self.generator.optimizer_instance(self.generator.parameters(),
92 | lr=1e-3,
93 | weight_decay=0.00001)
94 | self.generator = self.generator.cuda()
95 |
96 | def load_model(self, load_filename):
97 | self.generator.load_state_dict(torch.load(load_filename))
98 |
99 | def generate_mc(self, topo, return_loss=False):
100 | emb_topo = self.generator.embedding_topo(topo)
101 | logits = F.relu(self.generator.rc_mc_1(emb_topo))
102 | pred_mc = F.softmax(self.generator.rc_mc_2(logits), dim=-1)
103 | m = Categorical(pred_mc)
104 | mc = m.sample()
105 | loss_mc = torch.log(pred_mc[mc])
106 |
107 | if return_loss:
108 | return mc, loss_mc
109 | else:
110 | return mc
111 |
112 | def generate_ol(self):
113 | num_fail = 0
114 | # generate smiles
115 | valid_smiles = False
116 | while not valid_smiles:
117 | gen_sf = self.generator.evaluate()
118 | try:
119 | gen_sm = sf.decoder(gen_sf[5:-5])
120 | m = Chem.MolFromSmiles(gen_sm)
121 | except Exception:
122 | num_fail += 1
123 | continue
124 |
125 | # encode selfies
126 | encoded_sf = sf.selfies_to_encoding(selfies=gen_sf,
127 | vocab_stoi=self.vocab_to_idx,
128 | enc_type="label")
129 |
130 | if m and 10 < len(encoded_sf) <= 100:
131 | valid_smiles = True
132 | else:
133 | num_fail += 1
134 |
135 | return encoded_sf, gen_sm, num_fail
136 |
137 | def init_generator(self):
138 | """
139 | Initialize stackRNN
140 | """
141 |
142 | hidden = self.generator.init_hidden()
143 | if self.generator.has_cell:
144 | cell = self.generator.init_cell()
145 | hidden = (hidden, cell)
146 | if self.generator.has_stack:
147 | stack = self.generator.init_stack()
148 | else:
149 | stack = None
150 |
151 | return hidden, stack
152 |
153 | def policy_gradient(self, n_batch, gamma, topo_idx=-1):
154 | """
155 | REINFORCE algorithm
156 | """
157 | self.generator.train()
158 |
159 | rl_loss = 0
160 | self.generator.optimizer.zero_grad()
161 | total_reward = 0
162 | num_fail = 0
163 |
164 | for n_epi in range(n_batch):
165 | # topology
166 | if topo_idx < 0:
167 | topo = torch.randint(0, len(self.topo_to_idx), size=(1,))[0].cuda()
168 | else:
169 | topo = torch.tensor(topo_idx).cuda()
170 |
171 | # metal cluster
172 | mc, loss_mc = self.generate_mc(topo, return_loss=True)
173 | rl_loss -= loss_mc
174 |
175 | # organic linker
176 | encoded_sf, _, n_f = self.generate_ol()
177 | num_fail += n_f
178 |
179 | # rmsd reward
180 | reward_rmsd, output_rmsd = self.get_reward_rmsd(topo, mc, encoded_sf, self.rmsd_predictor)
181 |
182 | # target reward
183 | reward_target, output_target = self.get_reward_target(topo, mc, encoded_sf, self.target_predictor)
184 |
185 | # REINFORCE algorithm
186 | discounted_reward = reward_rmsd + reward_target
187 | total_reward += reward_rmsd + reward_target
188 |
189 | # initialize hidden state and stack
190 | hidden, stack = self.init_generator()
191 |
192 | # accumulate trajectory
193 | trajectory = torch.LongTensor(encoded_sf).cuda()
194 |
195 | for p in range(len(trajectory) - 1):
196 |
197 | output, hidden, stack = self.generator(trajectory[p], hidden, stack)
198 | log_probs = F.log_softmax(output, dim=-1)
199 | top_i = trajectory[p + 1]
200 | rl_loss -= (log_probs[0, top_i] * discounted_reward)
201 | discounted_reward = discounted_reward * gamma
202 |
203 | # backpropagation
204 | rl_loss = rl_loss / n_batch
205 | total_reward = total_reward / n_batch
206 | num_fail = num_fail / n_batch
207 |
208 | rl_loss.backward()
209 | self.generator.optimizer.step()
210 |
211 | return total_reward, rl_loss.item(), num_fail
212 |
213 | def test_estimate(self, n_to_generate, topo_idx=-1):
214 | with torch.no_grad():
215 | self.generator.evaluate()
216 | gen_mofs = []
217 | rewards = {"rmsd": [], "target": []}
218 | outputs = {"rmsd": [], "target": []}
219 | num_fail = 0
220 |
221 | for i in range(n_to_generate):
222 | # topology
223 | if topo_idx < 0:
224 | topo = torch.randint(0, len(self.topo_to_idx), size=(1,))[0].cuda()
225 | else:
226 | topo = torch.tensor(topo_idx).cuda()
227 |
228 | # metal cluster
229 | mc = self.generate_mc(topo)
230 |
231 | # organic linker
232 | encoded_sf, gen_sm, n_f = self.generate_ol()
233 | num_fail += n_f
234 | gen_mofs.append([topo.detach().item(), mc.detach().item(), gen_sm])
235 |
236 | # rmsd reward
237 | reward_rmsd, output_rmsd = self.get_reward_rmsd(topo, mc, encoded_sf, self.rmsd_predictor)
238 |
239 | # target reward
240 | reward_target, output_target = self.get_reward_target(topo, mc, encoded_sf, self.target_predictor)
241 |
242 | rewards["rmsd"].append(reward_rmsd)
243 | rewards["target"].append(reward_target)
244 |
245 | outputs["rmsd"].append(output_rmsd)
246 | outputs["target"].append(output_target)
247 |
248 | return rewards, outputs, num_fail, gen_mofs
249 |
250 |
251 | def write_logs_test_and_estimate(self, test_reward, test_output, test_num_fail, gen_mofs, n_iter, n_to_generate):
252 | reward_rmsd = np.array(test_reward["rmsd"])
253 | reward_target = np.array(test_reward["target"])
254 | output_rmsd = np.array(test_output["rmsd"])
255 | output_target = np.array(test_output["target"])
256 |
257 | # add scalar to log
258 | self.writer.add_scalar("test/reward/rmsd", reward_rmsd.mean(), n_iter)
259 | self.writer.add_scalar("test/reward/target", reward_target.mean(), n_iter)
260 | self.writer.add_scalar("test/output/rmsd", output_rmsd.mean(), n_iter)
261 | self.writer.add_scalar("test/output/target", output_target.mean(), n_iter)
262 | self.writer.add_scalar("test/num_fail", test_num_fail / n_to_generate, n_iter)
263 | # add histogram to log
264 | fig, axes = plt.subplots(2, 2, figsize=(12, 8))
265 | data = [reward_rmsd, reward_target, output_rmsd, output_target]
266 | title = ["reward_rmsd", "reward_target", "output_rmsd", "output_target"]
267 | for i in range(4):
268 | ax = axes[i // 2, i % 2]
269 | ax.set_title(title[i])
270 | sns.histplot(data[i], ax=ax)
271 | self.writer.add_figure("test", fig, n_iter)
272 |
273 | # add image to log
274 | imgs = []
275 | for i, m in enumerate(gen_mofs):
276 | m = Chem.MolFromSmiles(m[2])
277 | img = MolToImage(m)
278 | img = np.array(img)
279 | img = torch.tensor(img)
280 | imgs.append(img)
281 | imgs = np.stack(imgs, axis=0)
282 | self.writer.add_images("test/gen_ol/", imgs, n_iter, dataformats="NHWC")
283 |
284 |
285 | def train(self,
286 | n_iters=10000,
287 | n_print=100,
288 | n_to_generate=200,
289 | n_batch=10,
290 | gamma=0.80,
291 | topo_idx=-1,
292 | ):
293 | """
294 |
295 | :param n_iters: number of iterations
296 | :param n_print: interval (in iterations) between logging and test evaluation
297 | :param n_to_generate: number of MOFs generated at each test evaluation
298 | :param n_batch: number of episodes per policy-gradient update
299 | :param gamma: discount factor of REINFORCE
300 | :param topo_idx: topology index; if topo_idx < 0, a topology is selected randomly
301 | :return: None
302 | """
303 |
304 | reward = 0
305 | loss = 0
306 | num_fail = 0
307 | callback = 0
308 |
309 | for n_iter in range(n_iters):
310 |
311 | batch_reward, batch_loss, batch_num_fail = self.policy_gradient(n_batch, gamma, topo_idx)
312 | reward += batch_reward
313 | loss += batch_loss
314 | num_fail += batch_num_fail
315 |
316 | if n_iter % n_print == 0:
317 | self.writer.add_scalar("epoch", n_iter, n_iter)
318 | if n_iter != 0:
319 | print(f"########## iteration {n_iter} ##########")
320 | print(
321 | f"train | reward : {reward / n_print:.3f} ,"
322 | f" loss : {loss / n_print:.3f} , "
323 | f"num_fail : {num_fail / n_print:.3f}")
324 | self.writer.add_scalar("train/loss", loss / n_print, n_iter)
325 | self.writer.add_scalar("train/reward", reward / n_print, n_iter)
326 | self.writer.add_scalar("train/num_fail", num_fail / n_print, n_iter)
327 | reward = 0
328 | loss = 0
329 | num_fail = 0
330 |
331 | test_reward, test_output, test_num_fail, gen_mofs = self.test_estimate(n_to_generate, topo_idx)
332 |
333 | test_reward_rmsd = np.mean(test_reward['rmsd'])
334 | test_reward_target = np.mean(test_reward["target"])
335 | total_test_reward = test_reward_rmsd + test_reward_target
336 |
337 | test_output_rmsd = np.mean(test_output['rmsd'])
338 | test_output_target = np.mean(test_output["target"])
339 |
340 | print(
341 | f"test | reward_rmsd : {test_reward_rmsd:.3f} , reward_target : {test_reward_target:.3f}, "
342 | f"output_rmsd : {test_output_rmsd:.3f}, output_target : {test_output_target:.3f},"
343 | f" num_fail : {test_num_fail}/{n_to_generate}")
344 | print(gen_mofs[:5])
345 |
346 | self.write_logs_test_and_estimate(test_reward,
347 | test_output,
348 | test_num_fail,
349 | gen_mofs,
350 | n_iter,
351 | n_to_generate,
352 | )
353 |
354 | if total_test_reward > callback:
355 | path_ckpt = os.path.join(self.log_dir, f"reinforce_model_{n_iter}.ckpt")
356 | torch.save(self.generator.state_dict(), path_ckpt)
357 | print("model save !!!")
358 | callback = total_test_reward
359 |
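The inner loop of `policy_gradient` above weights each token's log-probability by a reward that decays by `gamma` at every step, so earlier tokens carry more weight. A minimal sketch of that accumulation in plain Python follows; `discounted_reinforce_loss` is a hypothetical helper, and the log-probabilities are passed in as floats rather than computed by the generator.

```python
import math


def discounted_reinforce_loss(log_probs, reward, gamma):
    """Sum of -log_prob_t * reward * gamma**t over one trajectory.

    Hypothetical helper; log_probs holds the log-probability of each
    chosen token, in generation order.
    """
    loss = 0.0
    discounted_reward = reward
    for log_prob in log_probs:
        loss -= log_prob * discounted_reward
        # decay the reward before moving to the next token
        discounted_reward *= gamma
    return loss


# two tokens with probabilities 0.5 and 0.25, reward 2.0, gamma 0.5
example_loss = discounted_reinforce_loss([math.log(0.5), math.log(0.25)], 2.0, 0.5)
```

In the class above, the same quantity is accumulated over `n_batch` episodes and divided by `n_batch` before `backward()` is called.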
--------------------------------------------------------------------------------
/mofreinforce/run_generator.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 |
4 | import pytorch_lightning as pl
5 | from pytorch_lightning.strategies.ddp import DDPStrategy
6 | from pytorch_lightning.loggers import TensorBoardLogger
7 |
8 | from generator.config_generator import ex
9 | from generator.datamodule import GeneratorDatamodule
10 | from generator.module import Generator
11 |
12 |
13 | @ex.automain
14 | def main(_config):
15 | _config = copy.deepcopy(_config)
16 | pl.seed_everything(_config["seed"])
17 |
18 | dm = GeneratorDatamodule(_config)
19 |
20 | model = Generator(_config)
21 |
22 | exp_name = f"{_config['exp_name']}"
23 |
24 | os.makedirs(_config["log_dir"], exist_ok=True)
25 | checkpoint_callback = pl.callbacks.ModelCheckpoint(
26 | save_top_k=1,
27 | verbose=True,
28 | monitor="val/the_metric",
29 | mode="max",
30 | save_last=True,
31 | )
32 |
33 | logger = pl.loggers.TensorBoardLogger(
34 | _config["log_dir"],
35 | name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}',
36 | )
37 |
38 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
39 | callbacks = [checkpoint_callback, lr_callback]
40 |
41 | num_devices = _config["devices"]
42 | if isinstance(num_devices, list):
43 | num_devices = len(num_devices)
44 |
45 | # gradient accumulation (at least 1, even if batch_size is smaller than one step)
46 | accumulate_grad_batches = max(1, _config["batch_size"] // (
47 | _config["per_gpu_batchsize"] * num_devices * _config["num_nodes"]
48 | ))
49 |
50 | trainer = pl.Trainer(
51 | accelerator="gpu",
52 | devices=num_devices,
53 | num_nodes=_config["num_nodes"],
54 | precision=_config["precision"],
55 | deterministic=True,
56 | strategy=DDPStrategy(find_unused_parameters=True),
57 | max_epochs=_config["max_epochs"],
58 | callbacks=callbacks,
59 | logger=logger,
60 | accumulate_grad_batches=accumulate_grad_batches,
61 | log_every_n_steps=10,
62 | resume_from_checkpoint=_config["resume_from"],
63 | val_check_interval=_config["val_check_interval"],
64 | gradient_clip_val=_config["gradient_clip_val"],
65 | # profiler="simple",
66 | )
67 |
68 | if not _config["test_only"]:
69 | trainer.fit(model, datamodule=dm)
70 | trainer.test(model, datamodule=dm)
71 | else:
72 | trainer.test(model, datamodule=dm)
73 |
74 |
75 |
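The gradient-accumulation arithmetic in this script can be checked in isolation. The sketch below uses a hypothetical helper, `accumulation_steps`, with made-up batch sizes; clamping the result to at least 1 is an assumption of this sketch, added because integer division yields 0 when the effective batch is smaller than one step.

```python
def accumulation_steps(batch_size, per_gpu_batchsize, num_devices, num_nodes):
    """Number of batches to accumulate so the effective batch is batch_size.

    Hypothetical helper; argument names follow the config keys used above.
    """
    # samples processed by all devices in a single optimizer-free step
    samples_per_step = per_gpu_batchsize * num_devices * num_nodes
    # clamp to 1 (assumption of this sketch) to avoid a zero step count
    return max(1, batch_size // samples_per_step)
```

With `batch_size=1024`, `per_gpu_batchsize=64`, four devices, and one node, each step covers 256 samples, so four steps are accumulated per optimizer update.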
--------------------------------------------------------------------------------
/mofreinforce/run_predictor.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 |
4 | import pytorch_lightning as pl
5 | from pytorch_lightning.strategies.ddp import DDPStrategy
6 | from pytorch_lightning.loggers import TensorBoardLogger
7 |
8 | from predictor.config_predictor import ex
9 | from predictor.datamodule import Datamodule
10 | from predictor.module import Predictor
11 |
12 |
13 | @ex.automain
14 | def main(_config):
15 | _config = copy.deepcopy(_config)
16 | pl.seed_everything(_config["seed"])
17 |
18 | dm = Datamodule(_config)
19 |
20 | model = Predictor(_config)
21 |
22 | exp_name = f"{_config['exp_name']}"
23 |
24 | os.makedirs(_config["log_dir"], exist_ok=True)
25 | checkpoint_callback = pl.callbacks.ModelCheckpoint(
26 | save_top_k=1,
27 | verbose=True,
28 | monitor="val/the_metric",
29 | mode="max",
30 | save_last=True,
31 | )
32 |
33 | logger = pl.loggers.TensorBoardLogger(
34 | _config["log_dir"],
35 | name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}',
36 | )
37 |
38 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
39 | callbacks = [checkpoint_callback, lr_callback]
40 |
41 | num_gpus = _config["devices"]
42 | if isinstance(num_gpus, list):
43 | num_gpus = len(num_gpus)
44 |
45 | # gradient accumulation (at least 1, even if batch_size is smaller than one step)
46 | accumulate_grad_batches = max(1, _config["batch_size"] // (
47 | _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"]
48 | ))
49 |
50 | trainer = pl.Trainer(
51 | gpus=num_gpus,
52 | num_nodes=_config["num_nodes"],
53 | precision=_config["precision"],
54 | deterministic=True,
55 | strategy=DDPStrategy(find_unused_parameters=True),
56 | max_epochs=_config["max_epochs"],
57 | callbacks=callbacks,
58 | logger=logger,
59 | accumulate_grad_batches=accumulate_grad_batches,
60 | log_every_n_steps=10,
61 | resume_from_checkpoint=_config["resume_from"],
62 | val_check_interval=_config["val_check_interval"],
63 | # profiler="simple",
64 | )
65 |
66 | if not _config["test_only"]:
67 | trainer.fit(model, datamodule=dm)
68 | trainer.test(model, datamodule=dm)
69 | else:
70 | trainer.test(model, datamodule=dm)
71 |
--------------------------------------------------------------------------------
/mofreinforce/run_reinforce.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 |
4 | import pytorch_lightning as pl
5 | from pytorch_lightning.strategies.ddp import DDPStrategy
6 | from pytorch_lightning.loggers import TensorBoardLogger
7 |
8 | from reinforce.config_reinforce import ex
9 |
10 | from predictor.config_predictor import config as predictor_config
11 | from predictor.config_predictor import _loss_names
12 | from predictor.module import Predictor
13 |
14 | from generator.config_generator import config as generator_config
15 | from generator.datamodule import GeneratorDatamodule
16 | from generator.module import Generator
17 |
18 | from reinforce.module import Reinforce
19 |
20 | @ex.automain
21 | def main(_config):
22 | pl.seed_everything(_config["seed"])
23 | # 1. load predictor
24 | predictors = []
25 | for i in range(len(_config["predictor_load_path"])):
26 | pred_config = predictor_config()
27 | pred_config["test_only"] = True
28 | pred_config["loss_names"] = _loss_names({"regression": 1})
29 | pred_config["load_path"] = _config["predictor_load_path"][i]
30 | pred_config["mean"] = _config["mean"][i]
31 | pred_config["std"] = _config["std"][i]
32 |
33 | predictor = Predictor(pred_config)
34 | predictor.eval()
35 | predictors.append(predictor)
36 |
37 | # 2. load generator
38 | gen_config = generator_config()
39 | gen_config["load_path"] = _config["generator_load_path"]
40 | gen_config["batch_size"] = _config["batch_size"]
41 | gen_config["dataset_dir"] = _config["dataset_dir"]
42 |
43 | generator = Generator(gen_config)
44 | dm = GeneratorDatamodule(gen_config)
45 |
46 | # 3. set reinforce
47 | _config = copy.deepcopy(_config)
48 | pl.seed_everything(_config["seed"])
49 |
50 | model = Reinforce(generator, predictors, _config)
51 |
52 | exp_name = f"{_config['exp_name']}"
53 | os.makedirs(_config["log_dir"], exist_ok=True)
54 | checkpoint_callback = pl.callbacks.ModelCheckpoint(
55 | save_top_k=-1,
56 | verbose=True,
57 | monitor="val/total_reward",
58 | mode="max",
59 | save_last=True,
60 | )
61 |
62 | logger = pl.loggers.TensorBoardLogger(
63 | _config["log_dir"],
64 | name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}',
65 | )
66 | lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
67 | callbacks = [checkpoint_callback, lr_callback]
68 |
69 | trainer = pl.Trainer(
70 | accelerator="gpu",
71 | devices=_config["devices"],
72 | num_nodes=_config["num_nodes"],
73 | precision=_config["precision"],
74 | deterministic=True,
75 | strategy=DDPStrategy(find_unused_parameters=True),
76 | max_epochs=_config["max_epochs"],
77 | callbacks=callbacks,
78 | logger=logger,
79 | accumulate_grad_batches=_config["accumulate_grad_batches"],
80 | log_every_n_steps=10,
81 | resume_from_checkpoint=_config["resume_from"],
82 | limit_train_batches=_config["limit_train_batches"],
83 | limit_val_batches=_config["limit_val_batches"],
84 | num_sanity_val_steps=_config["limit_val_batches"],
85 | val_check_interval=_config["val_check_interval"],
86 | gradient_clip_val=_config["gradient_clip_val"],
87 | )
88 |
89 | if not _config["test_only"]:
90 | trainer.fit(model, datamodule=dm)
91 | # trainer.test(model, datamodule=dm)
92 | else:
93 | trainer.test(model, datamodule=dm)
94 |
95 |
96 |
--------------------------------------------------------------------------------
/mofreinforce/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hspark1212/MOFreinforce/b920084a29db9223482ddfeffd2e775b82f49e63/mofreinforce/utils/__init__.py
--------------------------------------------------------------------------------
/mofreinforce/utils/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wget
3 | import tarfile
4 | from pathlib import Path
5 | from mofreinforce import __root_dir__
6 |
7 |
8 | DEFAULT_PATH = Path(__root_dir__)
9 |
10 |
11 | class DownloadError(Exception):
12 | pass
13 |
14 |
15 | def _remove_tmp_file(direc: Path):
16 | tmp_list = direc.parent.glob("*.tmp")
17 | for tmp in tmp_list:
18 | if tmp.exists():
19 | os.remove(tmp)
20 |
21 |
22 | def _download_file(link, direc, name="target"):
23 | if direc.exists():
24 | print(f"{name} already exists.")
25 | return
26 | try:
27 | print(f"\n====Download {name} =============================================\n")
28 | filename = wget.download(link, out=str(direc))
29 | except KeyboardInterrupt:
30 | _remove_tmp_file(direc)
31 | raise
32 | except Exception as e:
33 | _remove_tmp_file(direc)
34 | raise DownloadError(e)
35 | else:
36 | print(
37 | f"\n====Successfully downloaded: {filename}=======================================\n"
38 | )
39 |
40 |
41 | def download_default(direc=None, remove_tarfile=False):
42 | """
43 | download data and pre-trained models, including a generator and predictors for DAC
44 | """
45 | if not direc:
46 | direc = Path(DEFAULT_PATH)
47 | if not direc.exists():
48 | direc.mkdir(parents=True, exist_ok=True)
49 | direc = direc / "default.tar.gz"
50 | else:
51 | direc = Path(direc)
52 | if direc.exists() and not direc.is_dir():
53 | raise ValueError(
54 | f"direc must be a path to a directory, not {direc}"
55 | )
56 | direc.mkdir(parents=True, exist_ok=True)
57 | direc = direc / "default.tar.gz"
58 |
59 | link = "https://figshare.com/ndownloader/files/40215109"
60 | name = "basic data and pretrained models"
61 | _download_file(link, direc, name)
62 |
63 | print(f"\n====Unzip : {name}===============================================\n")
64 | with tarfile.open(direc) as f:
65 | f.extractall(path=direc.parent)
66 |
67 | print(
68 | f"\n====Unzip successfully: {name}===============================================\n"
69 | )
70 |
71 | if remove_tarfile:
72 | os.remove(direc)
73 |
--------------------------------------------------------------------------------
/mofreinforce/utils/gadgets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchmetrics import Metric
3 |
4 |
5 | class Accuracy(Metric):
6 | def __init__(self, dist_sync_on_step=False):
7 | super().__init__(dist_sync_on_step=dist_sync_on_step)
8 | self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")
9 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
10 |
11 | def update(self, logits, target):
12 | logits, target = (
13 | logits.detach().to(self.correct.device),
14 | target.detach().to(self.correct.device),
15 | )
16 | if len(logits.shape) > 1:
17 | preds = logits.argmax(dim=-1)
18 | else:
19 | # binary accuracy
20 | # detach() shares storage with the input, so threshold
21 | # into a new tensor instead of writing into logits in place
22 | preds = (logits >= 0.5).to(logits.dtype)
23 |
24 | preds = preds[target != 0] # pad_idx
25 | target = target[target != 0] # pad_idx
26 |
27 | if target.numel() == 0:
28 | return 1
29 |
30 | assert preds.shape == target.shape
31 |
32 | self.correct += torch.sum(preds == target)
33 | self.total += target.numel()
34 |
35 | def compute(self):
36 | return self.correct / self.total
37 |
38 |
39 | class Scalar(Metric):
40 | def __init__(self, dist_sync_on_step=False):
41 | super().__init__(dist_sync_on_step=dist_sync_on_step)
42 | self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum")
43 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
44 |
45 | def update(self, scalar):
46 | if isinstance(scalar, torch.Tensor):
47 | scalar = scalar.detach().to(self.scalar.device)
48 | else:
49 | scalar = torch.tensor(scalar).float().to(self.scalar.device)
50 | self.scalar += scalar
51 | self.total += 1
52 |
53 | def compute(self):
54 | return self.scalar / self.total
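The padding mask used by `Accuracy.update` above can be shown on plain lists. This is a sketch under stated assumptions: `masked_accuracy` is a hypothetical helper, it works on Python lists instead of tensors, and it returns `None` when every target is padding (the metric above simply skips such a batch).

```python
def masked_accuracy(preds, targets, pad_idx=0):
    """Fraction of correct predictions, ignoring padded target positions.

    Hypothetical helper; pad_idx=0 follows the pad index used above.
    """
    # keep only positions whose target is a real token
    pairs = [(p, t) for p, t in zip(preds, targets) if t != pad_idx]
    if not pairs:
        return None  # nothing to score
    correct = sum(1 for p, t in pairs if p == t)
    return correct / len(pairs)
```

For instance, `masked_accuracy([3, 5, 7, 0], [3, 4, 7, 0])` scores three positions and counts two of them as correct; the trailing padded position is ignored.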
--------------------------------------------------------------------------------
/mofreinforce/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from rdkit import Chem
4 | from rdkit import RDLogger
5 |
6 | import libs.selfies as sf
7 |
8 | RDLogger.DisableLog('rdApp.*')
9 |
10 | topo_to_cn = json.load(open("data/final_topo_cn.json"))
11 | mc_to_cn = json.load(open("data/mc_cn.json"))
12 |
13 |
14 | class Metrics():
15 | def __init__(self, vocab_to_idx, idx_to_vocab):
16 | # generator
17 | self.num_fail = []
18 | self.conn_match = []
19 | self.scaffold = []
20 | # collect
21 | self.input_frags = []
22 | self.gen_ol = []
23 | self.gen_topo = []
24 | self.gen_mc = []
25 | # reinforce
26 | self.rewards = []
27 | self.preds = []
28 | # vocab
29 | self.topo_to_cn = topo_to_cn
30 | self.mc_to_cn = mc_to_cn
31 | self.vocab_to_idx = vocab_to_idx
32 | self.idx_to_vocab = idx_to_vocab
33 |
34 | def update(self, output, src, rewards=None, preds=None):
35 | """
36 | (1) metric for connection matching
37 | (2) metric for unique organic linker
38 | (3) metric for unique topology and metal cluster
39 | (4) metric for scaffold
40 | (5) (optional) rewards
41 | (6) (optional) preds
42 |
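43 | :param output: dict of generator outputs with keys "topo", "mc", "ol_idx", and "gen_sm"
44 | :param src: encoded input tensor, used to recover the input fragments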
45 | """
46 | topo_cn = list(topo_to_cn.get(output["topo"], [0])) # copy so the shared dict is not mutated; if topo is [PAD] then topo_cn = [0]
47 | if len(topo_cn) == 1:
48 | topo_cn.append(2)
49 | mc_cn = mc_to_cn.get(output["mc"], -1) # if mc is [PAD] then mc_cn = -1
50 | ol_cn = output["ol_idx"].count(self.vocab_to_idx["[*]"])
51 |
52 | # (1) metric for connection matching
53 | if set(topo_cn) == {mc_cn, ol_cn}:
54 | self.conn_match.append(1)
55 | else:
56 | self.conn_match.append(0)
57 |
58 | # (2) metric for unique organic linker
59 | # (3) metric for unique topology and metal cluster
60 | gen_sm = output["gen_sm"]
61 | self.gen_ol.append(gen_sm)
62 | self.gen_topo.append(output["topo"])
63 | self.gen_mc.append(output["mc"])
64 |
65 | # (4) metric for scaffold
66 | # get frags
67 | frags = [self.idx_to_vocab[idx.item()] for idx in src.squeeze(0)[3:]]
68 | frags = frags[:frags.index("[EOS]")]
69 | frags = "".join(frags)
70 |
71 | frags_sm = [sf.decoder(f) for f in frags.split(".")]
72 | self.input_frags.append(frags_sm)
73 | # replace * with H
74 | du, hy = Chem.MolFromSmiles('*'), Chem.MolFromSmiles('[H]')
75 | try:
76 | m = Chem.MolFromSmiles(gen_sm)
77 | check_gen_sm = Chem.ReplaceSubstructs(m, du, hy, replaceAll=True)[0]
78 |
79 | check_ = True
80 | for sm in frags_sm:
81 | if not check_gen_sm.HasSubstructMatch(Chem.MolFromSmiles(sm)):
82 | check_ = False
83 | if check_:
84 | self.scaffold.append(1)
85 | else:
86 | self.scaffold.append(0)
87 | except Exception:
88 | pass
89 |
90 | # (5) metric for reward
91 | self.rewards.append(rewards)
92 | self.preds.append(preds)
93 |
94 | @staticmethod
95 | def get_mean(list_):
96 | return torch.mean(torch.Tensor(list_))
97 |
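The connection-matching check in `Metrics.update` above compares the connection numbers required by the topology with those of the generated metal cluster and organic linker. A minimal sketch follows; `connection_match` is a hypothetical helper, and the connection numbers in the usage note are made-up values rather than entries from the repository's JSON files.

```python
def connection_match(topo_cn, mc_cn, ol_cn):
    """Return True if the building blocks fit the topology.

    Hypothetical helper mirroring step (1) of Metrics.update.
    """
    # copy so the caller's list is left untouched
    topo_cn = list(topo_cn)
    if len(topo_cn) == 1:
        # a topology with a single connection number is paired with 2,
        # as in the code above
        topo_cn.append(2)
    return set(topo_cn) == {mc_cn, ol_cn}
```

For example, `connection_match([4], 4, 2)` is `True`, while `connection_match([4], 4, 3)` is `False` because the linker's connection number does not appear in the topology.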
--------------------------------------------------------------------------------
/mofreinforce/utils/module_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import AdamW
4 | from transformers import (
5 | get_polynomial_decay_schedule_with_warmup,
6 | get_cosine_schedule_with_warmup,
7 | get_constant_schedule,
8 | get_constant_schedule_with_warmup,
9 | )
10 | from utils.gadgets import Scalar, Accuracy
11 |
12 |
13 | def init_weights(module):
14 | if isinstance(module, (nn.Linear, nn.Embedding)):
15 | module.weight.data.normal_(mean=0.0, std=0.02)
16 | elif isinstance(module, nn.LayerNorm):
17 | module.bias.data.zero_()
18 | module.weight.data.fill_(1.0)
19 |
20 | if isinstance(module, nn.Linear) and module.bias is not None:
21 | module.bias.data.zero_()
22 |
23 |
24 | def set_metrics(pl_module):
25 | for split in ["train", "val"]:
26 | for k, v in pl_module.hparams.config["loss_names"].items():
27 | if v < 1:
28 | continue
29 | if k == "regression" or k == "vfr":
30 | setattr(pl_module, f"{split}_{k}_loss", Scalar())
31 | setattr(pl_module, f"{split}_{k}_mae", Scalar())
32 | setattr(pl_module, f"{split}_{k}_r2", Scalar())
33 | elif k == "generator":
34 | setattr(pl_module, f"{split}_{k}_loss", Scalar())
35 | setattr(pl_module, f"{split}_{k}_acc_topo", Scalar())
36 | setattr(pl_module, f"{split}_{k}_acc_mc", Scalar())
37 | setattr(pl_module, f"{split}_{k}_acc_ol", Scalar())
38 | else:
39 | setattr(pl_module, f"{split}_{k}_accuracy", Accuracy())
40 | setattr(pl_module, f"{split}_{k}_loss", Scalar())
41 |
42 |
43 | def set_task(pl_module):
44 | pl_module.current_tasks = [
45 | k for k, v in pl_module.hparams.config["loss_names"].items() if v >= 1
46 | ]
47 | return
48 |
49 |
50 | def epoch_wrapup(pl_module):
51 |     phase = "train" if pl_module.training else "val"
52 |
53 |     the_metric = 0
54 |     for loss_name, v in pl_module.hparams.config["loss_names"].items():
55 |         if v < 1:
56 |             continue
57 |
58 |         if loss_name == "regression" or loss_name == "vfr":
59 |             # mse loss
60 |             pl_module.log(
61 |                 f"{loss_name}/{phase}/loss_epoch",
62 |                 getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
63 |             )
64 |             getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
65 |             # mae loss
66 |             value = getattr(pl_module, f"{phase}_{loss_name}_mae").compute()
67 |             pl_module.log(
68 |                 f"{loss_name}/{phase}/mae_epoch",
69 |                 value,
70 |             )
71 |             getattr(pl_module, f"{phase}_{loss_name}_mae").reset()
72 |
73 |             value = -value
74 |         elif loss_name == "generator":
75 |             acc_topo = getattr(pl_module, f"{phase}_{loss_name}_acc_topo").compute()
76 |             pl_module.log(f"{loss_name}/{phase}/acc_topo", acc_topo)
77 |             getattr(pl_module, f"{phase}_{loss_name}_acc_topo").reset()
78 |
79 |             acc_mc = getattr(pl_module, f"{phase}_{loss_name}_acc_mc").compute()
80 |             pl_module.log(f"{loss_name}/{phase}/acc_mc", acc_mc)
81 |             getattr(pl_module, f"{phase}_{loss_name}_acc_mc").reset()
82 |
83 |             acc_ol = getattr(pl_module, f"{phase}_{loss_name}_acc_ol").compute()
84 |             pl_module.log(f"{loss_name}/{phase}/acc_ol", acc_ol)
85 |             getattr(pl_module, f"{phase}_{loss_name}_acc_ol").reset()
86 |             value = acc_topo + acc_mc + acc_ol
87 |
88 |         else:
89 |             value = getattr(pl_module, f"{phase}_{loss_name}_accuracy").compute()
90 |             pl_module.log(f"{loss_name}/{phase}/accuracy_epoch", value)
91 |             getattr(pl_module, f"{phase}_{loss_name}_accuracy").reset()
92 |             pl_module.log(
93 |                 f"{loss_name}/{phase}/loss_epoch",
94 |                 getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
95 |             )
96 |             getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
97 |
98 |         the_metric += value
99 |
100 |     pl_module.log(f"{phase}/the_metric", the_metric)
101 |
102 |
103 | def set_schedule(pl_module):
104 |     lr = pl_module.hparams.config["learning_rate"]
105 |     wd = pl_module.hparams.config["weight_decay"]
106 |
107 |     no_decay = [
108 |         "bias",
109 |         "LayerNorm.bias",
110 |         "LayerNorm.weight",
111 |         "norm.bias",
112 |         "norm.weight",
113 |         "norm1.bias",
114 |         "norm1.weight",
115 |         "norm2.bias",
116 |         "norm2.weight",
117 |     ]
118 |     head_names = ["regression_head", "classification_head"]
119 |     lr_mult = pl_module.hparams.config["lr_mult"]
120 |     end_lr = pl_module.hparams.config["end_lr"]
121 |     decay_power = pl_module.hparams.config["decay_power"]
122 |     optim_type = pl_module.hparams.config["optim_type"]
123 |
124 |     optimizer_grouped_parameters = [
125 |         {
126 |             "params": [
127 |                 p
128 |                 for n, p in pl_module.named_parameters()
129 |                 if not any(nd in n for nd in no_decay)  # not within no_decay
130 |                 and not any(bb in n for bb in head_names)  # not within head_names
131 |             ],
132 |             "weight_decay": wd,
133 |             "lr": lr,
134 |         },
135 |         {
136 |             "params": [
137 |                 p
138 |                 for n, p in pl_module.named_parameters()
139 |                 if any(nd in n for nd in no_decay)  # within no_decay
140 |                 and not any(bb in n for bb in head_names)  # not within head_names
141 |             ],
142 |             "weight_decay": 0.0,
143 |             "lr": lr,
144 |         },
145 |         {
146 |             "params": [
147 |                 p
148 |                 for n, p in pl_module.named_parameters()
149 |                 if not any(nd in n for nd in no_decay)  # not within no_decay
150 |                 and any(bb in n for bb in head_names)  # within head_names
151 |             ],
152 |             "weight_decay": wd,
153 |             "lr": lr * lr_mult,
154 |         },
155 |         {
156 |             "params": [
157 |                 p
158 |                 for n, p in pl_module.named_parameters()
159 |                 if any(nd in n for nd in no_decay) and any(bb in n for bb in head_names)
160 |                 # within no_decay and head_names
161 |             ],
162 |             "weight_decay": 0.0,
163 |             "lr": lr * lr_mult,
164 |         },
165 |     ]
166 |
167 |     if optim_type == "adamw":
168 |         optimizer = AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8, betas=(0.9, 0.98))
169 |     elif optim_type == "adam":
170 |         optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=lr)
171 |     elif optim_type == "sgd":
172 |         optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=lr, momentum=0.9)
173 |     else:  # fail early instead of hitting a NameError on `optimizer` below
174 |         raise ValueError(f"unsupported optim_type: {optim_type!r}")
175 |     if pl_module.trainer.max_steps == -1:
176 |         max_steps = (
177 |             len(pl_module.trainer.datamodule.train_dataloader())
178 |             * pl_module.trainer.max_epochs
179 |             // pl_module.trainer.accumulate_grad_batches
180 |             // (pl_module.trainer.num_devices * pl_module.trainer.num_nodes)
181 |         )
182 |     else:
183 |         max_steps = pl_module.trainer.max_steps
184 |
185 |     warmup_steps = pl_module.hparams.config["warmup_steps"]
186 |     if isinstance(pl_module.hparams.config["warmup_steps"], float):
187 |         warmup_steps = int(max_steps * warmup_steps)
188 |
189 |     print(f"max_epochs: {pl_module.trainer.max_epochs} | max_steps: {max_steps} | warmup_steps : {warmup_steps} "
190 |           f"| weight_decay : {wd} | decay_power : {decay_power}")
191 |
192 |     if decay_power == "cosine":
193 |         scheduler = get_cosine_schedule_with_warmup(
194 |             optimizer,
195 |             num_warmup_steps=warmup_steps,
196 |             num_training_steps=max_steps,
197 |         )
198 |     elif decay_power == 'constant':
199 |         scheduler = get_constant_schedule(
200 |             optimizer,
201 |         )
202 |     elif decay_power == 'constant_with_warmup':
203 |         scheduler = get_constant_schedule_with_warmup(
204 |             optimizer,
205 |             num_warmup_steps=warmup_steps,
206 |         )
207 |     else:
208 |         scheduler = get_polynomial_decay_schedule_with_warmup(
209 |             optimizer,
210 |             num_warmup_steps=warmup_steps,
211 |             num_training_steps=max_steps,
212 |             lr_end=end_lr,
213 |             power=decay_power,
214 |         )
215 |
216 |     sched = {"scheduler": scheduler, "interval": "step"}
217 |
218 |     return (
219 |         [optimizer],
220 |         [sched],
221 |     )
222 |
223 |
224 | class Normalizer(object):
225 |     """
226 |     normalize for regression
227 |     """
228 |
229 |     def __init__(self, mean, std):
230 |         if mean is not None and std:  # a mean of 0.0 is valid; skip only if std is None or 0
231 |             self._norm_func = lambda tensor: (tensor - mean) / std
232 |             self._denorm_func = lambda tensor: tensor * std + mean
233 |         else:
234 |             self._norm_func = lambda tensor: tensor
235 |             self._denorm_func = lambda tensor: tensor
236 |
237 |         self.mean = mean
238 |         self.std = std
239 |
240 |     def encode(self, tensor):
241 |         return self._norm_func(tensor)
242 |
243 |     def decode(self, tensor):
244 |         return self._denorm_func(tensor)
245 |
--------------------------------------------------------------------------------
/predictor.md:
--------------------------------------------------------------------------------
1 | # Predictor
2 | ## 1. Prepare dataset
3 |
4 | Download the default data by running the following command:
5 | ```angular2html
6 | $ mofreinforce download default
7 | ```
8 |
9 | Example predictor datasets will then be downloaded to `data/dataset_predictor/qkh` and `data/dataset_predictor/selectivity`.
10 |
11 | The dataset directory should include `train.json`, `val.json` and `test.json`.
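
As a quick sanity check before training, the directory layout described above can be verified with a few lines of Python. This is only a sketch: the helper names `missing_splits` and `load_split` are hypothetical and not part of `mofreinforce`, and the only assumption taken from this document is that the three split files are JSON dictionaries.

```python
import json
from pathlib import Path

# File names taken from the sentence above; everything else is illustrative.
REQUIRED_SPLITS = ("train.json", "val.json", "test.json")


def missing_splits(dataset_dir):
    """Return the required split files that are absent from dataset_dir."""
    root = Path(dataset_dir)
    return [name for name in REQUIRED_SPLITS if not (root / name).is_file()]


def load_split(dataset_dir, split):
    """Load one split (e.g. "train") and check that it is a JSON object."""
    with open(Path(dataset_dir) / f"{split}.json") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        raise TypeError(f"{split}.json should map structure names to dictionaries")
    return data
```

For example, `missing_splits("data/dataset_predictor/qkh")` should return an empty list once the default data has been downloaded.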
12 |
13 | Each JSON file maps the names of structures (keys) to dictionaries describing those structures (values).
14 | Each description includes `topo_name`, `mc_name`, `ol_name`, `ol_selfies`, `topo`, `mc`, `ol` and `target`.
15 |
16 | (optional)
17 | - `topo_name` : (string) name of the topology.
18 | - `mc_name` : (string) name of the metal cluster.
19 | - `ol_name` : (string) name of the organic linker.
20 |
21 | (required) Topologies, metal clusters and organic linkers must be vectorized. Because topologies and metal clusters are categorical variables, they are converted to indices.
22 | - `topo` : (int) index of the topology. The index can be found in `data/topo_to_idx.json`.
23 | - `mc` : (int) index of the metal cluster. The index can be found in `data/mc_to_idx.json`.
24 |
25 | Organic linkers are represented by SELFIES.
26 | - `ol_selfies` : (string) SELFIES string of the organic linker. A SMILES string can be converted into SELFIES using `sf.encoder(SMILES)` from `libs/selfies`.
27 | - `ol` : (list) list of vocabulary indices for the symbols of the SELFIES string. The SELFIES vocabulary index can be found in `data/vocab_to_idx.json`.
28 |
29 | - `target` : (float) the target property you want to optimize.
30 |
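
The fields above can be assembled into one dataset entry roughly as follows. This is a minimal sketch under stated assumptions: the lookup tables are toy stand-ins for `data/topo_to_idx.json`, `data/mc_to_idx.json` and `data/vocab_to_idx.json` (every name and index in them is made up), `split_selfies` and `make_entry` are hypothetical helpers, and no special start or end tokens are added to `ol`, which the real preprocessing may do differently. The SELFIES string is taken as given here; in practice it would come from `sf.encoder` applied to the linker's SMILES.

```python
import json
import re

# Toy lookup tables with made-up contents; the real mappings are the JSON
# files named in the list above.
topo_to_idx = {"pcu": 0, "dia": 1}
mc_to_idx = {"N16": 0, "N10": 1}
vocab_to_idx = {"[C]": 0, "[=C]": 1, "[O]": 2, "[Ring1]": 3, "[Branch1]": 4}


def split_selfies(selfies):
    """Split a SELFIES string into its bracket-delimited symbols."""
    return re.findall(r"\[[^\]]*\]", selfies)


def make_entry(topo_name, mc_name, ol_name, ol_selfies, target):
    """Build the description dictionary for one structure (hypothetical helper)."""
    return {
        "topo_name": topo_name,
        "mc_name": mc_name,
        "ol_name": ol_name,
        "ol_selfies": ol_selfies,
        "topo": topo_to_idx[topo_name],
        "mc": mc_to_idx[mc_name],
        # Assumption: one vocabulary index per SELFIES symbol, no extra tokens.
        "ol": [vocab_to_idx[symbol] for symbol in split_selfies(ol_selfies)],
        "target": float(target),
    }


# The structure name is the key and the description is the value.
dataset = {"pcu+N16+L_0": make_entry("pcu", "N16", "L_0", "[C][=C][O]", 1.5)}
serialized = json.dumps(dataset, indent=2)
```

Writing `serialized` to `train.json`, `val.json` and `test.json` (with different structures in each) would give a directory in the shape described in this section.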
31 | ## 2. Training predictor
32 |
33 | The following example trains a predictor for heat of adsorption. Run it from the `mofreinforce` directory:
34 | ```angular2html
35 | $ python run_predictor.py with regression_qkh_round3
36 | ```
37 | To train your own predictors, modify `predictor/config_predictor.py`.
38 |
39 |
40 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==1.7.7
2 | torch<2.0.0
3 | torchmetrics
4 | sacred
5 | numpy
6 | scikit-learn
7 | tqdm
8 | timm
9 | transformers
10 | SmilesPE
11 | rdkit
12 | wget
13 | # for tutorial.ipynb
14 | pandas
15 | jupyterlab
16 | matplotlib
17 | seaborn
18 | ase
19 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import re
2 | from setuptools import setup, find_packages
3 |
4 | try:
5 |     import torch
6 | except ImportError:
7 |     raise EnvironmentError('Torch must be installed before installation')
8 |
9 | with open("requirements.txt", "r") as f:
10 |     install_requires = f.readlines()
11 |
12 | with open("README.md", "r") as f:
13 |     long_description = f.read()
14 |
15 | extras_require = {
16 |     'docs': ['sphinx', 'livereload', 'myst-parser']
17 | }
18 |
19 | # with open('mofreinforce/__init__.py') as f:
20 | #     version = re.search(r"__version__ = '(?P<version>.+)'", f.read()).group('version')
21 |
22 |
23 | setup(
24 |     name='mofreinforce',
25 |     version="0.0.1",
26 |     description='mofreinforce',
27 |     long_description=long_description,
28 |     long_description_content_type='text/markdown',
29 |     author='Hyunsoo Park',
30 |     author_email='phs68660888@gmail.com',
31 |     packages=find_packages(),
32 |     package_data={'mofreinforce': []},
33 |     install_requires=install_requires,
34 |     extras_require=extras_require,
35 |     scripts=[],
36 |     download_url='https://github.com/hspark1212/MOFreinforce',
37 |     entry_points={'console_scripts': ['mofreinforce=mofreinforce.cli.main:main']},
38 |     python_requires='>=3.8',
39 | )
--------------------------------------------------------------------------------