├── laiddmg
│   ├── models
│   │   ├── __init__.py
│   │   ├── vae
│   │   │   ├── __init__.py
│   │   │   ├── configuration.py
│   │   │   ├── vae_trainer.py
│   │   │   └── modeling.py
│   │   └── char_rnn
│   │       ├── __init__.py
│   │       ├── configuration.py
│   │       ├── modeling.py
│   │       └── char_rnn_trainer.py
│   ├── generate.sh
│   ├── train.vae.sh
│   ├── configuration_utils.py
│   ├── train.char_rnn.sh
│   ├── __init__.py
│   ├── train.py
│   ├── modeling_utils.py
│   ├── datasets.py
│   ├── generate.py
│   ├── logging_utils.py
│   ├── trainer.py
│   ├── common_parser.py
│   ├── utils.py
│   ├── tokenization_utils.py
│   ├── jupyter_char_rnn.ipynb
│   └── jupyter_vae.ipynb
├── datasets
│   └── moses
│       ├── test.csv.gz
│       └── train.csv.gz
├── LICENSE
├── .pre-commit-config.yaml
├── setup.py
├── .gitignore
└── README.md

--------------------------------------------------------------------------------
/laiddmg/models/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/laiddmg/models/vae/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/laiddmg/models/char_rnn/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/datasets/moses/test.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ilguyi/LAIDD-molecular-generation/HEAD/datasets/moses/test.csv.gz

--------------------------------------------------------------------------------
/datasets/moses/train.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ilguyi/LAIDD-molecular-generation/HEAD/datasets/moses/train.csv.gz

--------------------------------------------------------------------------------
/laiddmg/generate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | model=char_rnn
4 | # model=vae
5 | # echo $model
6 | 
7 | laiddmg-generate $model --seed 219 \
8 |                  --checkpoint_dir ./outputs/$model/exp1 \
9 |                  --weights_name ckpt_010.pt \
10 |                  --num_generation 10000 \
11 |                  --batch_size_for_generation 256
12 | 

--------------------------------------------------------------------------------
/laiddmg/train.vae.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | export CUDA_VISIBLE_DEVICES=0
4 | 
5 | model=vae
6 | echo $model
7 | 
8 | laiddmg-train $model --seed 219 \
9 |               --output_dir exp1 \
10 |               --dataset_path ../datasets \
11 |               --log_steps 10 \
12 |               --num_train_epochs 100 \
13 |               --train_batch_size 512 \
14 |               --lr 1e-4
15 | 

--------------------------------------------------------------------------------
/laiddmg/configuration_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import Union
4 | 
5 | 
6 | class ModelConfig:
7 | 
8 |   model_type: str = ''
9 | 
10 |   def __init__(self, **kwargs):
11 |     pass
12 | 
13 |   @classmethod
14 |   def from_pretrained(cls, json_file: Union[str, os.PathLike]) -> 'ModelConfig':
15 |     with open(json_file, 'r', encoding='utf-8') as reader:
16 |       text = reader.read()
17 |     config_dict = json.loads(text)
18 | 
19 |     return cls(**config_dict)
20 | 
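21 | 
22 | if __name__ == '__main__':
23 |   # Minimal usage sketch (added for illustration; not part of the original
24 |   # module). It round-trips a made-up config dict through a JSON file and
25 |   # rebuilds it with `from_pretrained`, as `generate.py` does for real
26 |   # checkpoints written by `config_to_json_file` in `utils.py`.
27 |   import tempfile
28 |   with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
29 |     json.dump({'vocab_size': 30}, f)
30 |   config = ModelConfig.from_pretrained(f.name)
31 |   assert isinstance(config, ModelConfig)
32 | 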
--------------------------------------------------------------------------------
/laiddmg/train.char_rnn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | export CUDA_VISIBLE_DEVICES=0
4 | 
5 | model=char_rnn
6 | echo $model
7 | 
8 | laiddmg-train $model --seed 219 \
9 |               --output_dir exp1 \
10 |               --dataset_path ../datasets \
11 |               --log_steps 10 \
12 |               --num_train_epochs 100 \
13 |               --train_batch_size 256 \
14 |               --lr 1e-3 \
15 |               --step_size 10 \
16 |               --gamma 0.5
17 | 
18 | #--train_batch_size 128 \
19 | 

--------------------------------------------------------------------------------
/laiddmg/models/char_rnn/configuration.py:
--------------------------------------------------------------------------------
1 | from ...configuration_utils import ModelConfig
2 | from ... import logging_utils
3 | 
4 | 
5 | logger = logging_utils.get_logger(__name__)
6 | 
7 | 
8 | class CharRNNConfig(ModelConfig):
9 | 
10 |   model_type: str = 'char_rnn'
11 | 
12 |   def __init__(
13 |     self,
14 |     tokenizer: str = 'moses',
15 |     vocab_size: int = 30,
16 |     embedding_dim: int = 32,
17 |     hidden_dim: int = 768,
18 |     num_layers: int = 3,
19 |     dropout: float = 0.2,
20 |     padding_value: int = 0,
21 |     **kwargs,
22 |   ):
23 | 
24 |     self.tokenizer = tokenizer
25 |     self.vocab_size = vocab_size
26 |     self.embedding_dim = embedding_dim
27 |     self.hidden_dim = hidden_dim
28 |     self.num_layers = num_layers
29 |     self.dropout = dropout
30 |     self.padding_value = padding_value
31 | 

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Il Gu Yi
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/laiddmg/models/vae/configuration.py:
--------------------------------------------------------------------------------
1 | from ...configuration_utils import ModelConfig
2 | from ... import logging_utils
3 | 
4 | 
5 | logger = logging_utils.get_logger(__name__)
6 | 
7 | 
8 | class VAEConfig(ModelConfig):
9 | 
10 |   model_type: str = 'vae'
11 | 
12 |   def __init__(
13 |     self,
14 |     tokenizer: str = 'moses',
15 |     vocab_size: int = 30,
16 |     embedding_dim: int = 30,
17 |     encoder_hidden_dim: int = 256,
18 |     encoder_num_layers: int = 1,
19 |     encoder_dropout: float = 0.5,
20 |     latent_dim: int = 128,
21 |     decoder_hidden_dim: int = 512,
22 |     decoder_num_layers: int = 3,
23 |     decoder_dropout: float = 0.0,
24 |     padding_value: int = 0,
25 |     **kwargs,
26 |   ):
27 | 
28 |     self.tokenizer = tokenizer
29 |     self.vocab_size = vocab_size
30 |     self.embedding_dim = embedding_dim
31 |     self.encoder_hidden_dim = encoder_hidden_dim
32 |     self.encoder_num_layers = encoder_num_layers
33 |     self.encoder_dropout = encoder_dropout
34 |     self.latent_dim = latent_dim
35 |     self.decoder_hidden_dim = decoder_hidden_dim
36 |     self.decoder_num_layers = decoder_num_layers
37 |     self.decoder_dropout = decoder_dropout
38 |     self.padding_value = padding_value
39 | 

--------------------------------------------------------------------------------
/laiddmg/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore F401 imported but unused warnings in this module
3 | # but to preserve other warnings. So, don't check this module at all.
4 | 
5 | 
6 | __version__ = '0.0.1'
7 | 
8 | 
9 | from collections import OrderedDict
10 | 
11 | # Parser
12 | from .common_parser import (
13 |   get_train_args,
14 |   get_generate_args,
15 | )
16 | 
17 | # Configs
18 | from .models.char_rnn.configuration import CharRNNConfig
19 | from .models.vae.configuration import VAEConfig
20 | 
21 | # Tokenizer
22 | from .tokenization_utils import Tokenizer
23 | 
24 | # Models
25 | from .models.char_rnn.modeling import CharRNNModel
26 | from .models.vae.modeling import VAEModel
27 | 
28 | # Trainers
29 | from .models.char_rnn.char_rnn_trainer import CharRNNTrainer
30 | from .models.vae.vae_trainer import VAETrainer
31 | 
32 | # Datasets
33 | from .datasets import get_rawdataset
34 | from .datasets import get_dataset
35 | 
36 | # Utils
37 | from .utils import (
38 |   set_output_dir,
39 |   set_output_dir_for_generation,
40 |   get_batch_size_list_for_generate,
41 |   measure_duration_time,
42 | )
43 | 
44 | # Logging utils
45 | from . import logging_utils
46 | 
47 | 
48 | TRAINER_MAPPING = OrderedDict(
49 |   [
50 |     ('char_rnn', CharRNNTrainer),
51 |     ('vae', VAETrainer),
52 |   ]
53 | )
54 | 

--------------------------------------------------------------------------------
/laiddmg/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import os
4 | import sys
5 | 
6 | from datetime import datetime
7 | 
8 | from laiddmg import (
9 |   get_train_args,
10 |   set_output_dir,
11 |   Tokenizer,
12 |   CharRNNConfig,
13 |   CharRNNModel,
14 |   VAEConfig,
15 |   VAEModel,
16 |   get_rawdataset,
17 |   get_dataset,
18 |   TRAINER_MAPPING,
19 |   measure_duration_time,
20 | )
21 | 
22 | from . import logging_utils
23 | 
24 | 
25 | def main():
26 | 
27 |   start_time = datetime.now()
28 |   # get training args
29 |   args = get_train_args()
30 |   model_type = sys.argv[1]
31 |   args = set_output_dir(model_type, args)
32 |   logger = logging_utils.get_logger(__name__, os.path.join(args.output_dir, 'out.log'))
33 | 
34 |   logger.info(args)
35 |   logger.info(f'model type: {model_type}')
36 |   logger.info(f'use device: {args.device}')
37 | 
38 |   assert model_type in ['char_rnn', 'vae']
39 |   tokenizer = Tokenizer()
40 |   if model_type == 'char_rnn':
41 |     config = CharRNNConfig()
42 |     model = CharRNNModel(config)
43 |   else:
44 |     config = VAEConfig()
45 |     model = VAEModel(config)
46 | 
47 |   print(config)
48 |   print(tokenizer('c1ccccc1'))
49 |   print(model)
50 | 
51 |   print(model.device)
52 |   print(model.dtype)
53 |   print(model.num_parameters())
54 | 
55 |   # get raw dataset (SMILES)
56 |   train = get_rawdataset('train')
57 | 
58 |   # get PyTorch Dataset
59 |   train_dataset = get_dataset(train, tokenizer)
60 | 
61 |   # get trainer
62 |   trainer = TRAINER_MAPPING[model_type]
63 |   t = trainer(model=model,
64 |               args=args,
65 |               train_dataset=train_dataset,
66 |               tokenizer=tokenizer)
67 |   t.train()
68 | 
69 |   end_time = datetime.now()
70 |   measure_duration_time(end_time - start_time)
71 | 
72 | 
73 | if __name__ == '__main__':
74 |   main()
75 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # @see https://pre-commit.com/hooks.html
2 | # @see https://github.com/readthedocs/common/blob/a7af72b7e5946d62fd3b4dc524d6406604488f6f/pre-commit-config.yaml
3 | # @see https://github.com/google/yapf
4 | 
5 | repos:
6 | 
7 | - repo: https://github.com/pre-commit/pre-commit-hooks.git
8 |   rev: v2.3.0
9 |   hooks:
10 |   # prevent giant files from being committed
11 |   - id: check-added-large-files
12 |     args: ['--maxkb=5120']
13 |   # check for UTF-8 byte-order marker
14 |   - id: check-byte-order-marker
15 |   # check for files that could conflict in case-insensitive filesystems
16 |   - id: check-case-conflict
17 |   # check for debugger imports and Python 3.7 or above breakpoint() calls
18 |   - id: debug-statements
19 |   # replace double quoted strings with single quoted strings
20 |   - id: double-quote-string-fixer
21 |   # ensure that a file is either empty or ends with single new line
22 |   - id: end-of-file-fixer
23 |     exclude: 'static'
24 |   # add '# -*- coding: utf-8 -*-' at the top
25 |   - id: fix-encoding-pragma
26 |     args: ['--remove'] # remove the encoding pragma
27 |   # check for merge conflict strings
28 |   - id: check-merge-conflict
29 |   # check for symlinks pointing to nothing
30 |   - id: check-symlinks
31 |   # trim trailing whitespaces
32 |   - id: trailing-whitespace
33 |   # check mixed line ending
34 |   - id: mixed-line-ending
35 |     args: ['--fix=lf'] # make lines end with LF
36 |   # sort all entries and remove entries with version 0.0.0 in requirements.txt
37 |   - id: requirements-txt-fixer
38 |   # ensure that files in tests/ directory end in *_test.py
39 |   - id: name-tests-test
40 |   #   args: ['--django'] # allow test*.py
41 |   # check against PEP 8
42 |   - id: flake8
43 |     args: ['--ignore=E111,E114,E121,E122,E127,E501,W503'] # ignore specific errors
44 | 
45 | # - repo: https://github.com/pre-commit/mirrors-yapf.git
46 | #   rev: v0.26.0
47 | #   hooks:
48 | #   # apply a Python style guide
49 | #   - id: yapf
50 | #     exclude: 'migrations|tests'
51 | #     args: ['--style=.style.yapf', '--parallel', '--in-place']
52 | #     # args: ['--style={based_on_style:pep8,indent_width:2,continuation_indent_width:2}',
53 | #     #        '--parallel',
54 | #     #        '--in-place']
55 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import re
2 | from setuptools import find_packages
3 | from setuptools import setup
4 | 
5 | 
6 | def find_version(path):
7 |   with open(path, 'r', encoding='utf-8') as stream:
8 |     return re.search(
9 |       r"^__version__ = ['\"]([^'\"]*)['\"]",
10 |       stream.read(),
11 |       re.M,
12 |     ).group(1)
13 | 
14 | 
15 | def read_long_description(path):
16 |   with open(path, 'r', encoding='utf-8') as stream:
17 |     return stream.read()
18 | 
19 | 
20 | # @see https://python-packaging.readthedocs.io/en/latest/index.html # noqa
21 | # @see https://setuptools.readthedocs.io/en/latest/setuptools.html#new-and-changed-setup-keywords # noqa
22 | # @see https://packaging.python.org/guides/packaging-namespace-packages/#native-namespace-packages # noqa
23 | setup(
24 |   name='LAIDD-mol-gen',
25 |   version=find_version('laiddmg/__init__.py'),
26 |   author='Il Gu Yi',
27 |   author_email='ilgu.yi.219@gmail.com',
28 |   url='https://github.com/ilguyi/LAIDD-molecular-generation',
29 |   license='MIT',
30 |   description='Molecular generation modules for LAIDD Lecture',
31 |   long_description=read_long_description('README.md'),
32 |   long_description_content_type='text/markdown',
33 |   packages=find_packages(),
34 |   classifiers=[
35 |     'Environment :: GPU :: NVIDIA CUDA :: 11.0',
36 |     'License :: OSI Approved :: MIT License',
37 |     'Operating System :: OS Independent',
38 |     'Programming Language :: Python :: 3',
39 |     'Programming Language :: Python :: 3.7',
40 |     'Topic :: Scientific/Engineering :: Artificial Intelligence',
41 |   ],
42 |   keywords=[''],
43 |   zip_safe=False,
44 |   python_requires='>=3.7',
45 |   install_requires=[
46 |     # LIST THE DEPENDENCIES OF YOUR PACKAGE
47 |     # FOR INSTANCE,
48 |     # 'numpy>=1.19.2',
49 |     'torch>=1.0',
50 |     'numpy',
51 |     'pandas',
52 |     # 'scikit-learn',
53 |     # 'scipy',
54 |     'flake8',
55 |     'easydict',
56 |     'notebook',
57 |     'matplotlib',
58 |     'pre-commit',
59 |     'tensorboard',
60 |     'tqdm',
61 |   ],
62 |   entry_points={
63 |     'console_scripts': [
64 |       'laiddmg-train = laiddmg.train:main',
65 |       'laiddmg-generate = laiddmg.generate:main',
66 |     ]
67 |   },
68 | )
69 | 
--------------------------------------------------------------------------------
/laiddmg/modeling_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import abstractmethod
3 | from typing import List, Union, Any
4 | 
5 | import torch
6 | import torch.nn as nn
7 | 
8 | from .tokenization_utils import Tokenizer
9 | 
10 | from . import logging_utils
11 | 
12 | 
13 | logger = logging_utils.get_logger(__name__)
14 | 
15 | 
16 | class BaseModel(nn.Module):
17 | 
18 |   @property
19 |   def device(self) -> torch.device:
20 |     try:
21 |       return next(self.parameters()).device
22 |     except StopIteration:
23 |       return torch.device('cpu')
24 | 
25 |   @property
26 |   def dtype(self) -> torch.dtype:
27 |     try:
28 |       return next(self.parameters()).dtype
29 |     except StopIteration:
30 |       return torch.get_default_dtype()
31 | 
32 |   def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
33 | 
34 |     def parameter_filter(x):
35 |       return (x.requires_grad or not only_trainable) and not (
36 |         isinstance(x, nn.Embedding) and exclude_embeddings
37 |       )
38 | 
39 |     params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
40 |     return sum(p.numel() for p in params)
41 | 
42 |   @abstractmethod
43 |   def generate(self, **kwargs):
44 |     pass
45 | 
46 |   @classmethod
47 |   def from_pretrained(
48 |     cls,
49 |     config: Any,
50 |     ckpt_path: Union[str, os.PathLike],
51 |     output_info: bool = False
52 |   ):
53 |     model = cls(config)
54 |     state_dict = torch.load(ckpt_path, map_location='cpu')
55 | 
56 |     epoch = state_dict['epoch']
57 |     global_step = state_dict['global_step']
58 |     model.load_state_dict(state_dict['model_state_dict'])
59 |     logger.info('All keys matched successfully; weights loaded.')
60 | 
61 |     if output_info:
62 |       return model, epoch, global_step
63 |     else:
64 |       return model
65 | 
66 |   def postprocessing(
67 |     self,
68 |     generated_sequences: torch.LongTensor,
69 |     tokenizer: Tokenizer,
70 |   ) -> List[List[int]]:
71 |     pad_index = tokenizer.convert_token_to_id(tokenizer.pad_token)
72 |     end_index = tokenizer.convert_token_to_id(tokenizer.end_token)
73 |     new_sequences = []
74 |     for sequence in generated_sequences:
75 |       new_seq = []
76 |       for token in sequence:
77 |         if token.item() not in [pad_index, end_index]:
78 |           new_seq.append(token.item())
79 |         else:
80 |           break
81 |       new_sequences.append(new_seq)
82 | 
83 |     return new_sequences
84 | 
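85 | 
86 | if __name__ == '__main__':
87 |   # Minimal sketch (added for illustration; run with
88 |   # `python -m laiddmg.modeling_utils`). `TinyModel` is a made-up subclass:
89 |   # any concrete model inherits parameter counting and device/dtype
90 |   # introspection from `BaseModel`.
91 |   class TinyModel(BaseModel):
92 |     def __init__(self):
93 |       super().__init__()
94 |       self.fc = nn.Linear(4, 2)
95 | 
96 |   m = TinyModel()
97 |   print(m.device, m.dtype, m.num_parameters())  # e.g. cpu torch.float32 10
98 | 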
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | #  Usually these files are written by a python script from a template
32 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | #   install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/laiddmg/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Dict, Union, Callable
3 | 
4 | import numpy as np
5 | import pandas as pd
6 | 
7 | import torch
8 | from torch.utils.data import Dataset
9 | 
10 | from .tokenization_utils import Tokenizer
11 | 
12 | 
13 | from . import logging_utils
14 | 
15 | 
16 | logger = logging_utils.get_logger(__name__)
17 | 
18 | 
19 | def get_rawdataset(split: str = 'train',
20 |                    path: Union[str, os.PathLike] = '../datasets/moses') -> np.ndarray:
21 | 
22 |   assert split in ['train', 'test']
23 | 
24 |   smiles_path = os.path.join(path, f'{split}.csv.gz')
25 | 
26 |   logger.info(f'read {smiles_path} file')
27 |   smiles = pd.read_csv(smiles_path)['smiles'].values
28 | 
29 |   logger.info(f'number of {split} dataset: {len(smiles)}')
30 | 
31 |   return smiles
32 | 
33 | 
34 | class SMILESDataset(Dataset):
35 | 
36 |   def __init__(self,
37 |                smiles: Union[List[str], np.ndarray],
38 |                tokenizer: Tokenizer,
39 |                transform: Callable = None):
40 | 
41 |     self.smiles = np.asarray(smiles)
42 |     assert len(self.smiles.shape) == 1, 'dataset must be `1-D` array'
43 |     self.tokenizer = tokenizer
44 |     self.transform = transform
45 | 
46 |   def __len__(self) -> int:
47 |     return len(self.smiles)
48 | 
49 |   def __getitem__(self, index: int) -> Dict[str, Union[List[str], List[int], int]]:
50 |     smiles = self.smiles[index]
51 |     input_id = self.tokenizer(smiles, return_tensors=False)
52 | 
53 |     sample = {
54 |       'index': index,
55 |       'smiles': smiles,
56 |       'input_id': input_id[:-1],
57 |       'target': input_id[1:],
58 |       'length': len(input_id[:-1])
59 |     }
60 | 
61 |     if self.transform:
62 |       sample = self.transform(sample)
63 | 
64 |     return sample
65 | 
66 | 
67 | class ToTensor:
68 | 
69 |   def __init__(self):
70 |     pass
71 | 
72 |   def __call__(self, sample: Dict) -> Dict:
73 |     inp = sample['input_id']
74 |     tar = sample['target']
75 |     length = sample['length']
76 |     sample['input_id'] = torch.LongTensor(inp)
77 |     sample['target'] = torch.LongTensor(tar)
78 |     sample['length'] = torch.LongTensor([length])
79 | 
80 |     return sample
81 | 
82 | 
83 | def get_dataset(smiles: Union[List[str], np.ndarray] = None,
84 |                 tokenizer: Tokenizer = None) -> SMILESDataset:
85 |   if smiles is not None:
86 |     if tokenizer is not None:
87 |       return SMILESDataset(smiles,
88 |                            tokenizer,
89 |                            transform=ToTensor())
90 |     else:
91 |       raise ValueError('tokenizer must be provided.')
92 |   else:
93 |     raise ValueError('smiles must be provided.')
94 | 
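95 | 
96 | if __name__ == '__main__':
97 |   # Minimal usage sketch (added for illustration; run with
98 |   # `python -m laiddmg.datasets`). The two SMILES strings are made up and
99 |   # stand in for the moses training data.
100 |   toy_smiles = ['c1ccccc1', 'CCO']
101 |   dataset = get_dataset(toy_smiles, Tokenizer())
102 |   sample = dataset[0]
103 |   print(sample['smiles'], sample['input_id'], sample['target'], sample['length'])
104 | 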
--------------------------------------------------------------------------------
/laiddmg/generate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import os
4 | import sys
5 | from datetime import datetime
6 | import pandas as pd
7 | 
8 | from laiddmg import (
9 |   get_generate_args,
10 |   set_output_dir_for_generation,
11 |   get_batch_size_list_for_generate,
12 |   measure_duration_time,
13 |   CharRNNConfig,
14 |   VAEConfig,
15 |   Tokenizer,
16 |   CharRNNModel,
17 |   VAEModel,
18 | )
19 | 
20 | from . import logging_utils
21 | 
22 | 
23 | def main():
24 | 
25 |   start_time = datetime.now()
26 |   # get generation args
27 |   args = get_generate_args()
28 |   model_type = sys.argv[1]
29 |   args = set_output_dir_for_generation(model_type, args)
30 |   logger = logging_utils.get_logger(__name__, os.path.join(args.output_dir, 'out.log'))
31 | 
32 |   logger.info(f'args: {args}')
33 |   logger.info(f'model type: {model_type}')
34 |   logger.info(f'use device: {args.device}')
35 | 
36 |   # get tokenizer, config, and model
37 |   assert model_type in ['char_rnn', 'vae']
38 |   if model_type == 'char_rnn':
39 |     config = CharRNNConfig.from_pretrained(os.path.join(f'{args.checkpoint_dir}', 'config.json'))
40 |     tokenizer = Tokenizer()
41 |     model = CharRNNModel.from_pretrained(config, os.path.join(f'{args.checkpoint_dir}', f'{args.weights_name}'))
42 |   else:
43 |     config = VAEConfig.from_pretrained(os.path.join(f'{args.checkpoint_dir}', 'config.json'))
44 |     tokenizer = Tokenizer()
45 |     model = VAEModel.from_pretrained(config, os.path.join(f'{args.checkpoint_dir}', f'{args.weights_name}'))
46 | 
47 |   logger.info(f'model type: {config.model_type}')
48 |   logger.info(f'model config: {config}')
49 |   logger.info(f'tokenizer vocab: {tokenizer.vocab}')
50 |   logger.info(f'model: {model}')
51 |   logger.info(f'model device: {model.device}')
52 |   logger.info(f'model dtype: {model.dtype}')
53 | 
54 |   model.to(args.device)
55 |   logger.info(f'generate on device: {model.device}')
56 |   model.eval()
57 | 
58 |   batch_size_list = get_batch_size_list_for_generate(args)
59 |   generated_smiles = []
60 |   for bs in batch_size_list:
61 |     outputs = model.generate(tokenizer=tokenizer,
62 |                              max_length=args.max_length,
63 |                              num_return_sequences=bs,
64 |                              skip_special_tokens=True)
65 |     generated_smiles += outputs
66 | 
67 |   savefile_path = os.path.join(args.output_dir, 'generated_smiles.csv')
68 |   logger.info(f'file path to save: {savefile_path}')
69 |   pd.DataFrame({'smiles': generated_smiles}).to_csv(
70 |     savefile_path, header=True, index=False
71 |   )
72 | 
73 |   end_time = datetime.now()
74 |   measure_duration_time(end_time - start_time)
75 | 
76 | 
77 | if __name__ == '__main__':
78 |   main()
79 | 
""" 19 | 20 | 21 | import logging 22 | import os 23 | import sys 24 | import threading 25 | 26 | from typing import Optional 27 | 28 | 29 | _lock = threading.Lock() 30 | _stream_handler: str = None 31 | _file_handler: str = None 32 | library_root_logger: str = None 33 | 34 | 35 | log_levels = { 36 | 'debug': logging.DEBUG, 37 | 'info': logging.INFO, 38 | 'warning': logging.WARNING, 39 | 'error': logging.ERROR, 40 | 'critical': logging.CRITICAL, 41 | } 42 | 43 | _default_log_level = log_levels['info'] 44 | 45 | 46 | def _get_library_name() -> str: 47 | 48 | return __name__.split('.')[0] 49 | 50 | 51 | def _get_library_root_logger() -> logging.Logger: 52 | 53 | return logging.getLogger(_get_library_name()) 54 | 55 | 56 | def _configure_library_root_logger(log_path: str = None) -> None: 57 | 58 | global _stream_handler, _file_handler, library_root_logger 59 | 60 | formatter = logging.Formatter('[%(levelname)s|%(filename)s:%(lineno)s %(asctime)s >> %(message)s') 61 | 62 | with _lock: 63 | 64 | if _stream_handler is not None: 65 | pass 66 | else: 67 | _stream_handler = logging.StreamHandler() 68 | _stream_handler.flush = sys.stderr.flush 69 | 70 | library_root_logger = _get_library_root_logger() 71 | library_root_logger.setLevel(_default_log_level) 72 | 73 | library_root_logger.addHandler(_stream_handler) 74 | library_root_logger.propagate = False 75 | 76 | _stream_handler.setFormatter(formatter) 77 | 78 | if log_path is not None: 79 | if _file_handler is not None: 80 | raise ValueError(f'{log_path} must be one file.') 81 | else: 82 | log_dir = os.path.dirname(log_path) 83 | if log_dir != '': 84 | os.makedirs(log_dir, exist_ok=True) 85 | _file_handler = logging.FileHandler(log_path) 86 | library_root_logger.addHandler(_file_handler) 87 | 88 | _file_handler.setFormatter(formatter) 89 | 90 | 91 | def get_logger(name: Optional[str] = None, 92 | log_path: str = None) -> logging.Logger: 93 | 94 | if name is None: 95 | name = _get_library_name() 96 | 97 | _configure_library_root_logger(log_path) 98 | 99 | return logging.getLogger(name) 100 | -------------------------------------------------------------------------------- /laiddmg/trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from abc import abstractmethod 4 | from typing import Dict, Union, Any, Optional 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.utils.data import Dataset, DataLoader 10 | # from torch.utils.tensorboard import SummaryWriter 11 | 12 | from .tokenization_utils import Tokenizer 13 | from .utils import ( 14 | set_seed, 15 | args_and_config_to_json_files, 16 | ) 17 | 18 | 19 | from . 
--------------------------------------------------------------------------------
/laiddmg/trainer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from abc import abstractmethod
4 | from typing import Dict, Union, Any, Optional
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torch.utils.data import Dataset, DataLoader
10 | # from torch.utils.tensorboard import SummaryWriter
11 | 
12 | from .tokenization_utils import Tokenizer
13 | from .utils import (
14 |   set_seed,
15 |   args_and_config_to_json_files,
16 | )
17 | 
18 | 
19 | from . import logging_utils
20 | 
21 | 
22 | logger = logging_utils.get_logger(__name__)
23 | 
24 | 
25 | class Trainer:
26 | 
27 |   def __init__(
28 |     self,
29 |     model: nn.Module = None,
30 |     args: argparse.Namespace = None,
31 |     train_dataset: Dataset = None,
32 |     tokenizer: Optional[Tokenizer] = None,
33 |     optimizer: optim.Optimizer = None,
34 |     scheduler: optim.lr_scheduler = None,
35 |     **kwargs,
36 |   ):
37 |     self.model = model
38 |     self.args = args
39 |     self.train_dataset = train_dataset
40 |     self.tokenizer = tokenizer
41 |     self.optimizer = optimizer
42 |     self.scheduler = scheduler
43 |     self.global_step = 0
44 | 
45 |     set_seed(self.args.seed)
46 |     if model is None:
47 |       raise RuntimeError('Trainer requires a `model` argument.')
48 | 
49 |     logger.info('`training from scratch`')
50 | 
51 |     # tb_log_dir = os.path.join(args.output_dir, 'logs')
52 |     # self.tb_writer = SummaryWriter(tb_log_dir)
53 |     # logger.info(f'Created tensorboard writer in {tb_log_dir}.')
54 | 
55 |     if self.args.device == torch.device('cuda:0'):
56 |       logger.info('Use one gpu')
57 |     else:
58 |       logger.info('Use only cpu')
59 |     self.model = self.model.to(self.args.device)
60 | 
61 |     # get train_dataloader
62 |     self.train_dataloader = self.get_train_dataloader()
63 |     # add num_train_steps_per_epoch to args
64 |     self.args.num_training_steps_per_epoch = len(self.train_dataloader)
65 |     self.args.num_training_steps = len(self.train_dataloader) * args.num_train_epochs
66 |     logger.info(f'the number of training steps per epoch: {self.args.num_training_steps_per_epoch}')
67 |     logger.info(f'the total number of training steps: {self.args.num_training_steps}')
68 | 
69 |     args_and_config_to_json_files(self.args, self.model.config)
70 | 
71 |   @abstractmethod
72 |   def _collate_fn(self, **kwargs):
73 |     pass
74 | 
75 |   def get_train_dataloader(self) -> DataLoader:
76 |     if self.train_dataset is None:
77 |       raise ValueError('Trainer: training requires a `train_dataset`.')
78 | 
79 |     return DataLoader(
80 |       self.train_dataset,
81 |       batch_size=self.args.train_batch_size,
82 |       shuffle=True,
83 |       collate_fn=self._collate_fn,
84 |       num_workers=16,
85 |     )
86 | 
87 |   def _prepare_inputs(
88 |     self,
89 |     inputs: Dict[str, Union[torch.Tensor, Any]],
90 |   ) -> Dict[str, Union[torch.Tensor, Any]]:
91 |     # This function is borrowed from `huggingface.transformer`
92 |     for k, v in inputs.items():
93 |       if isinstance(v, torch.Tensor):
94 |         inputs[k] = v.to(self.args.device)
95 | 
96 |     return inputs
97 | 
98 |   @abstractmethod
99 |   def train(self):
100 |     pass
101 | 
102 |   def save_model(self, epoch: int):
103 |     checkpoint_dir = os.path.join(self.args.output_dir)
104 |     ckpt_name = f'ckpt_{epoch:03d}.pt'
105 |     ckpt_path = os.path.join(checkpoint_dir, ckpt_name)
106 | 
107 |     torch.save({'epoch': epoch,
108 |                 'global_step': self.global_step,
109 |                 'model_state_dict': self.model.state_dict()},
110 |                ckpt_path)
111 |     logger.info(f'saved {self.model.config.model_type} model at epoch {epoch}.')
112 | 
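113 | 
114 | # Note (added for orientation): the concrete trainers below -- `CharRNNTrainer`
115 | # and `VAETrainer` -- implement the abstract `_collate_fn` (batching of
116 | # `SMILESDataset` samples) and `train` (the optimization loop), and reuse
117 | # `get_train_dataloader`, `_prepare_inputs`, and `save_model` from this base.
118 | 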
--------------------------------------------------------------------------------
/laiddmg/common_parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import torch
4 | 
5 | from argparse import ArgumentParser
6 | 
7 | from .models.char_rnn.char_rnn_trainer import train_parser as char_rnn_parser
8 | from .models.vae.vae_trainer import train_parser as vae_parser
9 | 
10 | from .models.char_rnn.char_rnn_trainer import generate_parser as char_rnn_g_parser
11 | from .models.vae.vae_trainer import generate_parser as vae_g_parser
12 | 
13 | from . import logging_utils
14 | 
15 | 
16 | logger = logging_utils.get_logger(__name__)
17 | 
18 | 
19 | def add_common_args(parser: ArgumentParser) -> None:
20 |   parser = parser.add_argument_group('common')
21 | 
22 |   parser.add_argument('--seed',
23 |                       type=int,
24 |                       default=219,
25 |                       help='seed number')
26 |   parser.add_argument('--output_dir',
27 |                       default=None,
28 |                       type=str,
29 |                       help='directory where to save checkpoint')
30 | 
31 | 
32 | def add_train_args(parser: ArgumentParser) -> None:
33 |   add_common_args(parser)
34 | 
35 |   parser = parser.add_argument_group('train')
36 | 
37 |   parser.add_argument('--dataset_path',
38 |                       default='../datasets',
39 |                       type=str,
40 |                       help='dataset path where train, test datasets are')
41 |   parser.add_argument('--log_steps',
42 |                       default=10,
43 |                       type=int,
44 |                       help='number of steps between two (tensorboard) log writes.')
45 | 
46 | 
47 | def get_train_parser() -> ArgumentParser:
48 |   parser = ArgumentParser('Molecular generation train tool',
49 |                           usage='laiddmg-train <model> [<args>]')
50 |   subparser = parser.add_subparsers()
51 | 
52 |   # get parser of all models
53 |   add_train_args(char_rnn_parser(subparser))
54 |   add_train_args(vae_parser(subparser))
55 | 
56 |   return parser
57 | 
58 | 
59 | def add_generate_args(parser: ArgumentParser) -> None:
60 |   add_common_args(parser)
61 | 
62 |   parser = parser.add_argument_group('generate')
63 | 
64 |   parser.add_argument('--checkpoint_dir',
65 |                       type=str,
66 |                       required=True,
67 |                       help='directory where to load checkpoint')
68 |   parser.add_argument('--weights_name',
69 |                       default=None,
70 |                       type=str,
71 |                       help='checkpoint file name to load weights')
72 |   parser.add_argument('--num_generation',
73 |                       default=10000,
74 |                       type=int,
75 |                       help='the number of generated SMILES')
76 |   parser.add_argument('--max_length',
77 |                       default=128,
78 |                       type=int,
79 |                       help='the maximum length of the sequence to be generated')
80 | 
81 | 
82 | def get_generate_parser() -> ArgumentParser:
83 |   parser = ArgumentParser('Molecular generation generate tool',
84 |                           usage='laiddmg-generate <model> [<args>]')
85 |   subparser = parser.add_subparsers()
86 | 
87 |   # get parser of all models
88 |   add_generate_args(char_rnn_g_parser(subparser))
89 |   add_generate_args(vae_g_parser(subparser))
90 | 
91 |   return parser
92 | 
93 | 
94 | def setup_devices(args) -> torch.device:
95 |   logger.info('PyTorch: setting up devices')
96 |   device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
97 |   logger.info(f'set torch.device: {device}')
98 | 
99 |   if device.type == 'cuda':
100 |     torch.cuda.set_device(device)
101 | 
102 |   return device
103 | 
104 | 
105 | def get_train_args() -> argparse.Namespace:
106 |   parser = get_train_parser()
107 |   args = parser.parse_args()
108 | 
109 |   args.device = setup_devices(args)
110 | 
111 |   return args
112 | 
113 | 
114 | def get_generate_args() -> argparse.Namespace:
115 |   parser = get_generate_parser()
116 |   args = parser.parse_args()
117 | 
118 |   args.device = setup_devices(args)
119 | 
120 |   return args
121 | 
--------------------------------------------------------------------------------
/laiddmg/models/char_rnn/modeling.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Union
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.nn.utils.rnn as rnn_utils
7 | 
8 | from .configuration import CharRNNConfig
9 | from ...tokenization_utils import Tokenizer
10 | from ...modeling_utils import BaseModel
11 | from ... import logging_utils
12 | 
13 | logger = logging_utils.get_logger(__name__)
14 | 
15 | 
16 | class CharRNNModel(BaseModel):
17 | 
18 |   def __init__(self, config: CharRNNConfig):
19 |     super(CharRNNModel, self).__init__()
20 |     self.config = config
21 |     self.vocab_size = config.vocab_size
22 |     self.embedding_dim = config.embedding_dim
23 |     self.hidden_dim = config.hidden_dim
24 |     self.num_layers = config.num_layers
25 |     self.dropout = config.dropout
26 |     self.padding_value = config.padding_value
27 |     self.output_dim = self.vocab_size
28 | 
29 |     self.embeddings = nn.Embedding(self.vocab_size, self.embedding_dim,
30 |                                    padding_idx=self.padding_value)
31 |     self.lstm = nn.LSTM(self.embedding_dim, self.hidden_dim,
32 |                         self.num_layers,
33 |                         batch_first=True,
34 |                         dropout=self.dropout)
35 |     self.fc = nn.Linear(self.hidden_dim, self.output_dim)
36 | 
37 |   def forward(
38 |     self,
39 |     input_ids: torch.Tensor,  # (batch_size, seq_len)
40 |     lengths: torch.Tensor,  # (batch_size,)
41 |     hiddens: Tuple[torch.Tensor] = None,  # (num_layers, batch_size, hidden_dim)
42 |     **kwargs,
43 |   ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
44 |     x = self.embeddings(input_ids)  # x: (batch_size, seq_len, embedding_dim)
45 |     x = rnn_utils.pack_padded_sequence(
46 |       x,
47 |       lengths.cpu(),
48 |       batch_first=True,
49 |       enforce_sorted=False,
50 |     )
51 |     x, hiddens = self.lstm(x, hiddens)
52 |     # hiddens: (h, c); (num_layers, batch_size, hidden_dim), respectively
53 |     x, _ = rnn_utils.pad_packed_sequence(
54 |       x,
55 |       batch_first=True,
56 |     )  # x: (batch_size, seq_len, hidden_dim)
57 |     outputs = self.fc(x)  # outputs: (batch_size, seq_len, vocab_size)
58 | 
59 |     return outputs, hiddens
60 | 
61 |   def reset_states(self, batch_size: int):
62 |     h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(self.device)
63 |     c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(self.device)
64 | 
65 |     return (h0, c0)
66 | 
67 |   @torch.no_grad()
68 |   def generate(
69 |     self,
70 |     tokenizer: Tokenizer = None,
71 |     max_length: int = 128,
72 |     num_return_sequences: int = 1,
73 |     skip_special_tokens: bool = False,
74 |     **kwargs,
75 |   ) -> Union[List[List[int]], List[List[str]]]:
76 | 
77 |     initial_inputs = torch.full((num_return_sequences, 1),
78 |                                 tokenizer.convert_token_to_id(tokenizer.start_token),
79 |                                 dtype=torch.long,
80 |                                 device=self.device)
81 |     generated_sequences = initial_inputs
82 |     input_ids = initial_inputs  # input_ids: [batch_size, 1]
83 |     hiddens = self.reset_states(num_return_sequences)
84 | 
85 |     for i in range(max_length + 1):
86 |       x = self.embeddings(input_ids)  # x: [batch_size, 1, embedding_dim]
87 |       x, hiddens = self.lstm(x, hiddens)  # x: [batch_size, 1, hidden_dim]
88 |       logits = self.fc(x)  # logits: [batch_size, 1, vocab_size]
89 |       next_token_logits = logits.squeeze(1)  # next_token_logits: [batch_size, vocab_size]
90 | 
91 |       probabilities = F.softmax(next_token_logits, dim=-1)  # probabilities: [batch_size, vocab_size]
92 |       next_tokens = torch.multinomial(probabilities, num_samples=1)
93 |       # next_tokens: [batch_size, 1]
94 | 
95 |       input_ids = next_tokens
96 |       generated_sequences = torch.cat([generated_sequences, next_tokens], dim=1)
97 |       # generated_sequences: [batch_size, max_length]
98 | 
99 |     generated_sequences = self.postprocessing(generated_sequences, tokenizer)
100 | 
101 |     generated_SMILES = []
102 |     for sequence in generated_sequences:
103 |       generated_SMILES.append(tokenizer.decode(sequence, skip_special_tokens))
104 | 
105 |     return generated_SMILES
106 | 
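107 | 
108 | if __name__ == '__main__':
109 |   # Minimal usage sketch (added for illustration; run with
110 |   # `python -m laiddmg.models.char_rnn.modeling`): sampling from a freshly
111 |   # initialized, untrained model; real runs restore trained weights with
112 |   # `CharRNNModel.from_pretrained` as in `generate.py`.
113 |   model = CharRNNModel(CharRNNConfig())
114 |   model.eval()
115 |   samples = model.generate(tokenizer=Tokenizer(),
116 |                            max_length=16,
117 |                            num_return_sequences=2,
118 |                            skip_special_tokens=True)
119 |   print(samples)  # two short (most likely invalid) SMILES-like strings
120 | 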
--------------------------------------------------------------------------------
/laiddmg/models/char_rnn/char_rnn_trainer.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from typing import List, Dict, Union
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | 
8 | from ...trainer import Trainer
9 | from ... import logging_utils
10 | 
11 | 
12 | logger = logging_utils.get_logger(__name__)
13 | 
14 | 
15 | def train_parser(parser: ArgumentParser) -> ArgumentParser:
16 |   if parser is None:
17 |     parser = ArgumentParser()
18 | 
19 |   char_rnn_parser = parser.add_parser('char_rnn')
20 | 
21 |   char_rnn_parser.add_argument('--num_train_epochs',
22 |                                default=50,
23 |                                type=int,
24 |                                help='number of epochs for training')
25 |   char_rnn_parser.add_argument('--train_batch_size',
26 |                                default=64,
27 |                                type=int,
28 |                                help='batch size per device for training')
29 |   char_rnn_parser.add_argument('--lr',
30 |                                default=1e-3,
31 |                                type=float,
32 |                                help='learning rate for training')
33 |   char_rnn_parser.add_argument('--step_size',
34 |                                default=10,
35 |                                type=int,
36 |                                help='period of learning rate decay (decay unit: epoch)')
37 |   char_rnn_parser.add_argument('--gamma',
38 |                                default=0.5,
39 |                                type=float,
40 |                                help='multiplicative factor of learning rate decay')
41 | 
42 |   return char_rnn_parser
43 | 
44 | 
45 | def generate_parser(parser: ArgumentParser) -> ArgumentParser:
46 |   if parser is None:
47 |     parser = ArgumentParser()
48 | 
49 |   char_rnn_parser = parser.add_parser('char_rnn')
50 | 
51 |   char_rnn_parser.add_argument('--batch_size_for_generation',
52 |                                default=128,
53 |                                type=int,
54 |                                help='batch size for generation')
55 | 
56 |   return char_rnn_parser
57 | 
58 | 
59 | class CharRNNTrainer(Trainer):
60 | 
61 |   def __init__(self, **kwargs):
62 |     super(CharRNNTrainer, self).__init__(**kwargs)
63 |     pass
64 | 
65 |   def _pad_sequence(self,
66 |                     data: List[torch.Tensor],
67 |                     padding_value: int = 0) -> torch.Tensor:
68 |     return torch.nn.utils.rnn.pad_sequence(data,
69 |                                            batch_first=True,
70 |                                            padding_value=padding_value)
71 | 
72 |   def _collate_fn(self,
73 |                   batch: List[Dict[str, Union[torch.Tensor, str, int]]],
74 |                   **kwargs) -> Dict[str, Union[torch.Tensor, List[str], List[int]]]:
75 | 
76 |     indexes = [item['index'] for item in batch]
77 |     smiles = [item['smiles'] for item in batch]
78 |     input_ids = [item['input_id'] for item in batch]
79 |     targets = [item['target'] for item in batch]
80 |     lengths = [item['length'] for item in batch]
81 | 
82 |     padding_value = self.tokenizer.padding_value
83 |     input_ids = self._pad_sequence(input_ids, padding_value)
84 |     targets = self._pad_sequence(targets, padding_value)
85 |     lengths = torch.LongTensor(lengths)
86 | 
87 |     return {'input_ids': input_ids,
88 |             'targets': targets,
89 |             'lengths': lengths,
90 |             'smiles': smiles,
91 |             'indexes': indexes}
92 | 
93 |   def _train_step(
94 |     self,
95 |     data: Dict[str, Union[torch.Tensor, List[str], List[int]]],
96 |     loss_fn: 'nn.modules.loss',
97 |     optimizer: 'optim'
98 |   ) -> float:
99 |     optimizer.zero_grad()
100 | 
101 |     data = self._prepare_inputs(data)
102 |     outputs, _ = self.model(**data)
103 | 
104 |     loss = loss_fn(outputs.view(-1, outputs.shape[-1]),
105 |                    data['targets'].view(-1))
106 | 
107 |     loss.backward()
108 |     optimizer.step()
109 |     self.global_step += 1
110 | 
111 |     return loss.item()
112 | 
113 |   def _train_epoch(
114 |     self,
115 |     epoch: int,
116 |     loss_fn: 'nn.modules.loss',
117 |     optimizer: 'optim',
118 |     lr_scheduler: 'optim.lr_scheduler',
119 |   ):
120 |     self.model.train()
121 | 
122 |     for i, data in enumerate(self.train_dataloader):
123 |       loss = self._train_step(data, loss_fn, optimizer)
124 | 
125 |       if self.global_step % self.args.log_steps == 0:
126 |         logger.info(
127 |           f'{epoch} Epochs | {i + 1}/{self.args.num_training_steps_per_epoch} | loss: {loss:.4g} | '
128 |           f'lr: {lr_scheduler.get_last_lr()[0]:.4g}'
129 |         )
130 | 
131 |     lr_scheduler.step()
132 | 
133 |   def train(self):
134 | 
135 |     loss_fn = nn.CrossEntropyLoss(ignore_index=self.tokenizer.padding_value)
136 |     optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
137 |     lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
138 |                                              self.args.step_size,
139 |                                              self.args.gamma)
140 | 
141 |     for epoch in range(1, self.args.num_train_epochs + 1):
142 |       logger.info(f'Start training: {epoch} Epoch')
143 | 
144 |       self._train_epoch(epoch, loss_fn, optimizer, lr_scheduler)
145 |       self.save_model(epoch)
146 | 
147 |     logger.info('Training done!!')
148 | 
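149 | 
150 | if __name__ == '__main__':
151 |   # Minimal sketch (added for illustration): what `_pad_sequence` does to a
152 |   # batch of variable-length id tensors (0 is the pad id in the moses vocab).
153 |   seqs = [torch.LongTensor([1, 16, 2]), torch.LongTensor([1, 2])]
154 |   padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)
155 |   print(padded)  # tensor([[ 1, 16,  2], [ 1,  2,  0]])
156 | 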
--------------------------------------------------------------------------------
/laiddmg/models/vae/vae_trainer.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from typing import List, Dict, Tuple, Union
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | 
8 | from ...trainer import Trainer
9 | from ...utils import AnnealingSchedules
10 | from ... import logging_utils
11 | 
12 | 
13 | logger = logging_utils.get_logger(__name__)
14 | 
15 | 
16 | def train_parser(parser: ArgumentParser) -> ArgumentParser:
17 |   if parser is None:
18 |     parser = ArgumentParser()
19 | 
20 |   vae_parser = parser.add_parser('vae')
21 | 
22 |   vae_parser.add_argument('--num_train_epochs',
23 |                           default=100,
24 |                           type=int,
25 |                           help='number of epochs for training')
26 |   vae_parser.add_argument('--train_batch_size',
27 |                           default=512,
28 |                           type=int,
29 |                           help='batch size per device for training')
30 |   vae_parser.add_argument('--lr',
31 |                           default=1e-4,
32 |                           type=float,
33 |                           help='learning rate for training')
34 | 
35 |   return vae_parser
36 | 
37 | 
38 | def generate_parser(parser: ArgumentParser) -> ArgumentParser:
39 |   if parser is None:
40 |     parser = ArgumentParser()
41 | 
42 |   vae_parser = parser.add_parser('vae')
43 | 
44 |   vae_parser.add_argument('--batch_size_for_generation',
45 |                           default=128,
46 |                           type=int,
47 |                           help='batch size for generation')
48 | 
49 |   return vae_parser
50 | 
51 | 
52 | class VAETrainer(Trainer):
53 | 
54 |   def __init__(self, **kwargs):
55 |     super(VAETrainer, self).__init__(**kwargs)
56 |     pass
57 | 
58 |   def _pad_sequence(self,
59 |                     data: List[torch.Tensor],
60 |                     padding_value: int = 0) -> torch.Tensor:
61 |     return torch.nn.utils.rnn.pad_sequence(data,
62 |                                            batch_first=True,
63 |                                            padding_value=padding_value)
64 | 
65 |   def _collate_fn(self,
66 |                   batch: List[Dict[str, Union[torch.Tensor, str, int]]],
67 |                   **kwargs) -> Dict[str, Union[torch.Tensor, List[str], List[int]]]:
68 | 
69 |     indexes = [item['index'] for item in batch]
70 |     smiles = [item['smiles'] for item in batch]
71 |     input_ids = [item['input_id'] for item in batch]
72 |     targets = [item['target'] for item in batch]
73 |     lengths = [item['length'] for item in batch]
74 | 
75 |     padding_value = self.tokenizer.padding_value
76 |     input_ids = self._pad_sequence(input_ids, padding_value)
77 |     targets = self._pad_sequence(targets, padding_value)
78 |     lengths = torch.LongTensor(lengths)
79 | 
80 |     return {'input_ids': input_ids,
81 |             'targets': targets,
82 |             'lengths': lengths,
83 |             'smiles': smiles,
84 |             'indexes': indexes}
85 | 
86 |   def _train_step(
87 |     self,
88 |     data: Dict[str, Union[torch.Tensor, List[str], List[int]]],
89 |     loss_fn: 'nn.modules.loss',
90 |     optimizer: 'optim'
91 |   ) -> Tuple[float, float, float]:
92 |     optimizer.zero_grad()
93 | 
94 |     data = self._prepare_inputs(data)
95 |     outputs, z_mu, z_logvar = self.model(**data)
96 | 
97 |     reconstruction_loss = loss_fn(
98 |       outputs.view(-1, outputs.shape[-1]),
99 |       data['targets'].view(-1)
100 |     )
101 | 
102 |     kl_loss = .5 * (torch.exp(z_logvar) + z_mu**2 - 1. - z_logvar).sum(1).mean()
103 | 
104 |     kl_annealing_weight = self.kl_annealing(self.global_step)
105 | 
106 |     total_loss = reconstruction_loss + kl_annealing_weight * kl_loss
107 | 
108 |     total_loss.backward()
109 |     nn.utils.clip_grad_norm_(self.model.parameters(),
110 |                              max_norm=50)
111 |     optimizer.step()
112 |     self.global_step += 1
113 | 
114 |     return total_loss.item(), reconstruction_loss.item(), kl_loss.item()
115 | 
116 |   def _train_epoch(
117 |     self,
118 |     epoch: int,
119 |     loss_fn: 'nn.modules.loss',
120 |     optimizer: 'optim',
121 |   ):
122 |     self.model.train()
123 | 
124 |     for i, data in enumerate(self.train_dataloader):
125 |       total_loss, reconstruction_loss, kl_loss = self._train_step(data, loss_fn, optimizer)
126 | 
127 |       if self.global_step % self.args.log_steps == 0:
128 |         logger.info(
129 |           f'{epoch} Epochs | {i + 1}/{self.args.num_training_steps_per_epoch} | reconst_loss: {reconstruction_loss:.4g} '
130 |           f'kl_loss: {kl_loss:.4g}, total_loss: {total_loss:.4g}, kl_annealing: {self.kl_annealing(self.global_step - 1):.4g}'
131 |         )
132 | 
133 |   def train(self):
134 | 
135 |     reconstruction_loss_fn = nn.CrossEntropyLoss(ignore_index=self.tokenizer.padding_value)
136 |     optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
137 | 
138 |     # setup kl annealing weight
139 |     self.kl_annealing = AnnealingSchedules(
140 |       method='cycle_linear',
141 |       update_unit='epoch',
142 |       num_training_steps=self.args.num_training_steps,
143 |       num_training_steps_per_epoch=self.args.num_training_steps_per_epoch,
144 |       start_weight=0.0,
145 |       stop_weight=0.05,
146 |       n_cycle=1,
147 |       ratio=1.0,
148 |     )
149 | 
150 |     for epoch in range(1, self.args.num_train_epochs + 1):
151 |       logger.info(f'Start training: {epoch} Epoch')
152 | 
153 |       self._train_epoch(epoch, reconstruction_loss_fn, optimizer)
154 |       self.save_model(epoch)
155 | 
156 |     logger.info('Training done!!')
157 | 
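158 | 
159 | if __name__ == '__main__':
160 |   # Minimal sketch (added for illustration): the closed-form KL term used in
161 |   # `_train_step`, KL(N(mu, sigma^2) || N(0, I)) summed over latent
162 |   # dimensions, is zero when the posterior equals the prior.
163 |   z_mu, z_logvar = torch.zeros(4, 8), torch.zeros(4, 8)
164 |   kl = .5 * (torch.exp(z_logvar) + z_mu**2 - 1. - z_logvar).sum(1).mean()
165 |   print(kl.item())  # 0.0
166 | 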
--------------------------------------------------------------------------------
/laiddmg/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import datetime
4 | import json
5 | import os
6 | import random
7 | import numpy as np
8 | from typing import List, Any
9 | 
10 | import torch
11 | 
12 | from laiddmg import logging_utils
13 | 
14 | 
15 | OUTPUT_DIR = 'outputs'
16 | TRAINING_ARGS = 'training_args.json'
17 | CONFIG_NAME = 'config.json'
18 | 
19 | logger = logging_utils.get_logger(__name__)
20 | 
21 | 
22 | def set_output_dir(model_type: str, args: argparse.Namespace) -> argparse.Namespace:
23 |   if args.output_dir is not None:
24 |     output_dir = os.path.join(OUTPUT_DIR, model_type, args.output_dir)
25 |   else:
26 |     time_stamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
27 |     output_dir = os.path.join(OUTPUT_DIR, model_type, time_stamp)
28 | 
29 |   logger.info(f'output_dir: {output_dir}')
30 |   args.output_dir = output_dir
31 | 
32 |   return args
33 | 
34 | 
35 | def set_output_dir_for_generation(model_type: str, args: argparse.Namespace) -> argparse.Namespace:
36 |   if args.output_dir is not None:
37 |     output_dir = os.path.join(args.checkpoint_dir, args.output_dir)
38 |   else:
39 |     output_dir = os.path.join(args.checkpoint_dir, 'generate')
40 | 
41 |   logger.info(f'output_dir: {output_dir}')
42 |   args.output_dir = output_dir
43 | 
44 |   return args
45 | 
46 | 
47 | def measure_duration_time(duration_time: datetime.timedelta):
48 |   days = duration_time.days
49 |   seconds = duration_time.seconds
50 |   hours, remainder = divmod(seconds, 3600)
51 |   minutes, seconds = divmod(remainder, 60)
52 |   print(f'total duration time: {days}days {hours}hours {minutes}minutes {seconds}seconds')
53 | 
54 | 
55 | def set_seed(seed: int = 219):
56 |   logger.info(f'Set seed number {seed}')
57 |   random.seed(seed)
58 |   np.random.seed(seed)
59 |   torch.manual_seed(seed)
60 |   torch.cuda.manual_seed_all(seed)
61 | 
62 | 
63 | def args_to_json_file(args: argparse.Namespace):
64 |   args_dict = copy.deepcopy(vars(args))
65 |   if args_dict['device'] == torch.device('cuda:0'):
66 |     args_dict['device'] = 'cuda:0'
67 |   else:
68 |     args_dict['device'] = 'cpu'
69 | 
70 |   args_json_path = os.path.join(args.output_dir, TRAINING_ARGS)
71 |   logger.info(f'write training args to `{args_json_path}`')
72 |   with open(args_json_path, 'w', encoding='utf-8') as f:
73 |     f.write(json.dumps(args_dict, indent=2, sort_keys=True))
74 | 
75 | 
76 | def config_to_json_file(config: Any, output_dir: str = None):
77 |   config_dict = copy.deepcopy(vars(config))
78 | 
79 |   config_json_path = os.path.join(output_dir, CONFIG_NAME)
80 |   logger.info(f'write model config to `{config_json_path}`')
81 |   with open(config_json_path, 'w', encoding='utf-8') as f:
82 |     f.write(json.dumps(config_dict, indent=2, sort_keys=True))
83 | 
84 | 
85 | def args_and_config_to_json_files(
86 |   args: argparse.Namespace,
87 |   config: Any,
88 | ):
89 |   args_to_json_file(args)
90 |   config_to_json_file(config, args.output_dir)
91 | 
92 | 
93 | def get_batch_size_list_for_generate(args: argparse.Namespace) -> List[int]:
94 |   num_iters, remainder = divmod(args.num_generation, args.batch_size_for_generation)
95 |   if remainder != 0:
96 |     batch_size_list = [args.batch_size_for_generation] * num_iters + [remainder]
97 |   else:
98 |     batch_size_list = [args.batch_size_for_generation] * num_iters
99 | 
100 |   return batch_size_list
101 | 
102 | 
103 | # This code (`AnnealingSchedules` class) is borrowed from `https://github.com/haofuml/cyclical_annealing`
104 | # and modified by Il Gu Yi
105 | class AnnealingSchedules:
106 | 
107 |   def __init__(self,
108 |                method: str = 'cycle_linear',
109 |                update_unit: str = 'epoch',  # ('step' or 'epoch')
110 |                num_training_steps: int = None,
111 |                num_training_steps_per_epoch: int = None,
112 |                **kwargs):
113 |     self.method = method
114 |     assert update_unit in ['step', 'epoch']
115 |     self.update_unit = update_unit
116 |     self.num_training_steps = num_training_steps
117 |     self.num_training_steps_per_epoch = num_training_steps_per_epoch
118 |     self.kwargs = kwargs
119 | 
120 |     self._calculate_annealing_schedule(**self.kwargs)
121 | 
122 |   def _get_annealing_value(self, w: float) -> float:
123 |     if self.method == 'cycle_linear':
124 |       return w
125 |     elif self.method == 'cycle_sigmoid':
126 |       return 1.0 / (1.0 + np.exp(- (w * 12. - 6.)))
127 |     elif self.method == 'cycle_cosine':
128 |       return .5 - .5 * np.cos(w * np.pi)
129 | 
130 |   def _calculate_annealing_schedule(
131 |     self,
132 |     start_weight: float = 0.0,
133 |     stop_weight: float = 1.0,
134 |     n_cycle: int = 1,
135 |     ratio: float = 1.0,
136 |   ):
137 |     self.L = np.ones(self.num_training_steps) * stop_weight
138 |     period = self.num_training_steps / n_cycle
139 |     weight_step = (stop_weight - start_weight) / (period * ratio)  # linear schedule
140 | 
141 |     for c in range(n_cycle):
142 |       w, i = start_weight, 0
143 |       while w <= stop_weight and (int(i + c * period) < self.num_training_steps):
144 |         self.L[int(i + c * period)] = self._get_annealing_value(w)
145 |         w += weight_step
146 |         i += 1
147 | 
148 |     if self.update_unit == 'epoch':
149 |       for global_step, w in enumerate(self.L):
150 |         quotient = global_step // self.num_training_steps_per_epoch
151 |         self.L[global_step] = self.L[quotient * self.num_training_steps_per_epoch]
152 | 
153 |   def __call__(self, global_step: int):
154 |     assert global_step < self.num_training_steps
155 |     return self.L[global_step]
156 | 
157 |   def get_annealing_schedule(self):
158 |     return self.L
159 | 
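160 | 
161 | if __name__ == '__main__':
162 |   # Minimal usage sketch (added for illustration; run with
163 |   # `python -m laiddmg.utils`): a one-cycle linear schedule over 6 steps with
164 |   # 2 steps per epoch; weights are held constant within each epoch.
165 |   schedule = AnnealingSchedules(method='cycle_linear',
166 |                                 update_unit='epoch',
167 |                                 num_training_steps=6,
168 |                                 num_training_steps_per_epoch=2,
169 |                                 start_weight=0.0,
170 |                                 stop_weight=1.0,
171 |                                 n_cycle=1,
172 |                                 ratio=1.0)
173 |   print(schedule.get_annealing_schedule())  # [0. 0. 0.333... 0.333... 0.666... 0.666...]
174 | 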
--------------------------------------------------------------------------------
/laiddmg/tokenization_utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import List, Dict, Union, Optional
3 | 
4 | import numpy as np
5 | 
6 | import torch
7 | import torch.nn.utils.rnn as rnn_utils
8 | 
9 | from . import logging_utils
10 | 
11 | 
12 | logger = logging_utils.get_logger(__name__)
13 | 
14 | SPECIAL_TOKENS = OrderedDict([
15 |   ('pad_token', '<pad>'),
16 |   ('start_token', '<bos>'),
17 |   ('end_token', '<eos>'),
18 |   ('unknown_token', '<unk>'),
19 | ])
20 | 
21 | MOSES_VOCAB = OrderedDict([
22 |   (SPECIAL_TOKENS['pad_token'], 0),
23 |   (SPECIAL_TOKENS['start_token'], 1),
24 |   (SPECIAL_TOKENS['end_token'], 2),
25 |   (SPECIAL_TOKENS['unknown_token'], 3),
26 |   ('#', 4),
27 |   ('(', 5),
28 |   (')', 6),
29 |   ('-', 7),
30 |   ('1', 8),
31 |   ('2', 9),
32 |   ('3', 10),
33 |   ('4', 11),
34 |   ('5', 12),
35 |   ('6', 13),
36 |   ('=', 14),
37 |   ('B', 15),
38 |   ('C', 16),
39 |   ('F', 17),
40 |   ('H', 18),
41 |   ('N', 19),
42 |   ('O', 20),
43 |   ('S', 21),
44 |   ('[', 22),
45 |   (']', 23),
46 |   ('c', 24),
47 |   ('l', 25),
48 |   ('n', 26),
49 |   ('o', 27),
50 |   ('r', 28),
51 |   ('s', 29),
52 | ])
53 | 
54 | 
55 | class Tokenizer:
56 | 
57 |   def __init__(self, vocab_type: str = 'moses'):
58 |     self.vocab_type = vocab_type
59 |     self.vocab = MOSES_VOCAB
60 |     self.ids_to_tokens = OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
61 | 
62 |   @property
63 |   def vocab_size(self) -> int:
64 |     return len(self.vocab)
65 | 
66 |   def __len__(self) -> int:
67 |     return len(self.vocab)
68 | 
69 |   @property
70 |   def start_token(self) -> str:
71 |     return SPECIAL_TOKENS['start_token']
72 | 
73 |   @property
74 |   def end_token(self) -> str:
75 |     return SPECIAL_TOKENS['end_token']
76 | 
77 |   @property
78 |   def pad_token(self) -> str:
79 |     return SPECIAL_TOKENS['pad_token']
80 | 
81 |   @property
82 |   def padding_value(self) -> int:
83 |     return self.vocab[self.pad_token]
84 | 
85 |   def __call__(
86 |     self,
87 |     text: Union[str, List[str]],
88 |     add_special_tokens: str = 'both',  # one of [`start`, `end`, `both`, `none`]
89 |     max_length: Optional[int] = None,
90 |     return_tensors: Optional[bool] = True,
91 |   ) -> Dict:
92 |     assert isinstance(text, str) or isinstance(text, (list, tuple)) or isinstance(text, np.ndarray), (
93 |       'input must be of type `str` (single example), `List[str]` (batch example)'
94 |     )
95 | 
96 |     assert add_special_tokens in ['start', 'end', 'both', 'none']
97 |     if isinstance(text, str):
98 |       return self.encode(text, add_special_tokens, max_length, return_tensors)
99 |     else:
100 |       return self.batch_encode(text, add_special_tokens, max_length, return_tensors)
101 | 
`str` (single example), `List[str]` (batch example)' 94 | ) 95 | 96 | assert add_special_tokens in ['start', 'end', 'both', 'none'] 97 | if isinstance(text, str): 98 | return self.encode(text, add_special_tokens, max_length, return_tensors) 99 | else: 100 | return self.batch_encode(text, add_special_tokens, max_length, return_tensors) 101 | 102 | def encode( 103 | self, 104 | text: str, 105 | add_special_tokens: str = 'both', 106 | max_length: Optional[int] = None, 107 | return_tensors: Optional[bool] = True, 108 | ) -> List[int]: 109 | tokens = self.tokenize(text) 110 | 111 | if add_special_tokens == 'both': 112 | special_tokens_len = 2 113 | elif add_special_tokens == 'none': 114 | special_tokens_len = 0 115 | else: 116 | special_tokens_len = 1 117 | total_len = len(tokens) + special_tokens_len 118 | 119 | if max_length is not None: 120 | num_tokens_to_remove = total_len - max_length 121 | else: 122 | num_tokens_to_remove = 0 123 | 124 | tokens = self.truncate_sequences(tokens, num_tokens_to_remove) 125 | tokens = self.add_special_tokens(tokens, add_special_tokens) 126 | tokens = self.convert_tokens_to_ids(tokens) 127 | 128 | if return_tensors: 129 | return {'input_ids': torch.LongTensor([tokens]), 130 | 'lengths': torch.LongTensor([len(tokens)])} 131 | else: 132 | return tokens 133 | 134 | def truncate_sequences( 135 | self, 136 | tokens: List[str], 137 | num_tokens_to_remove: int = 0, 138 | ) -> List[str]: 139 | if num_tokens_to_remove <= 0: 140 | return tokens 141 | else: 142 | return tokens[:-num_tokens_to_remove] 143 | 144 | def tokenize(self, text: str) -> List[str]: 145 | return [token for token in text] 146 | 147 | def add_special_tokens(self, tokens: List[str], mode: str) -> List[str]: 148 | if mode == 'both': 149 | return self.add_end_token(self.add_start_token(tokens)) 150 | elif mode == 'start': 151 | return self.add_start_token(tokens) 152 | elif mode == 'end': 153 | return self.add_end_token(tokens) 154 | else: 155 | return tokens 156 | 157 | def add_start_token(self, tokens: List[str]) -> List[str]: 158 | return [self.start_token] + tokens 159 | 160 | def add_end_token(self, tokens: List[str]) -> List[str]: 161 | return tokens + [self.end_token] 162 | 163 | def convert_token_to_id(self, token: str) -> int: 164 | try: 165 | return self.vocab[token] 166 | except BaseException: 167 | return self.vocab[SPEICIAL_TOKENS['unknown_token']] 168 | 169 | def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: 170 | return [self.convert_token_to_id(token) for token in tokens] 171 | 172 | def batch_encode( 173 | self, 174 | text: Union[List[str], np.ndarray], 175 | add_special_tokens: str = 'both', 176 | max_length: Optional[int] = None, 177 | return_tensors: Optional[bool] = True, 178 | ) -> List[List[int]]: 179 | if return_tensors: 180 | encoded = [self.encode(t, 181 | add_special_tokens=add_special_tokens, 182 | max_length=max_length, 183 | return_tensors=False) 184 | for t in text] 185 | inputs = [] 186 | lengths = [] 187 | for e in encoded: 188 | inputs.append(torch.LongTensor(e)) 189 | lengths.append(len(e)) 190 | 191 | inputs = rnn_utils.pad_sequence(inputs, 192 | batch_first=True, 193 | padding_value=self.padding_value) 194 | lengths = torch.LongTensor(lengths) 195 | 196 | return {'input_ids': inputs, 197 | 'lengths': lengths} 198 | else: 199 | return [self.encode(t, 200 | add_special_tokens=add_special_tokens, 201 | max_length=max_length, 202 | return_tensors=return_tensors) 203 | for t in text] 204 | 205 | def decode( 206 | self, 207 | token_ids: Union[List[int], 
np.ndarray, torch.Tensor],
208 |         skip_special_tokens: bool = False
209 |     ) -> str:
210 |         assert isinstance(token_ids, (list, tuple)) or isinstance(token_ids, np.ndarray) or isinstance(token_ids, torch.Tensor)
211 |         if not isinstance(token_ids, (list, tuple)):
212 |             assert len(token_ids.shape) == 1, 'Only 1D arrays can be decoded'
213 |         if isinstance(token_ids, torch.Tensor):
214 |             assert token_ids.dtype in [torch.int, torch.long]
215 |             token_ids = token_ids.tolist()
216 | 
217 |         if skip_special_tokens:
218 |             SPEICIAL_TOKENS_VALUE = SPEICIAL_TOKENS.values()
219 |             decoded = [self.ids_to_tokens[index] for index in token_ids if self.ids_to_tokens[index] not in SPEICIAL_TOKENS_VALUE]
220 |             return ''.join(decoded)
221 |         else:
222 |             decoded = [self.ids_to_tokens[index] for index in token_ids]
223 |             return ''.join(decoded)
224 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Molecular generation for Lectures on AI-driven Drug Discovery
 2 | 
 3 | * **This repository was created as part of the Lectures on AI-driven Drug Discovery (LAIDD)
 4 | provided by the AI New Drug Development Support Center of the Korea Pharmaceutical and Bio-Pharma Manufacturers Association.**
 5 | 
 6 | 
 7 | In drug discovery, finding new compounds with the desired properties is
 8 | an important and difficult task.
 9 | The space of synthesizable molecules is
10 | astronomically
11 | vast, and
12 | finding the compounds you want in such a large space is very hard.
13 | For that reason it is often likened to searching for a needle on the sands of a wide ocean beach.
14 | 
15 | Thanks to the recent rapid progress of deep generative models,
16 | data of many kinds, such as images, text, and graphs,
17 | can now be generated with realistic quality.
18 | Generative models developed in the ML (Machine Learning) field have been
19 | applied to the chemistry field and have produced good results in de novo generation.
20 | 
21 | The goal of this lecture is to study and implement
22 | the RNN (Recurrent Neural Networks) and VAE (Variational AutoEncoders) models,
23 | which can be considered the foundational models for de novo molecular generation.
24 | Both models are built on SMILES data.
25 | SMILES (Simplified molecular-input line-entry system) is a way of representing
26 | a compound as text-based (ASCII) sequence data according to a specific grammar.
27 | RNN models handle such sequence data well, so they can be used for
28 | SMILES-based generative models or for QSAR models.
29 | 
30 | * Final update: 2023. 8. 22.
31 | * All rights reserved @ 이일구 (Il Gu Yi) 2021
32 | * This repository has been tested on Ubuntu, Linux Mint, and macOS.
33 | * Windows has not been tested separately, but the package should work there if the virtual environment and packages can be installed.
34 | * Lectures on AI-driven Drug Discovery (LAIDD) site: [https://www.laidd.org](https://www.laidd.org)
35 | 
36 | 
37 | ## Getting Started
38 | 
39 | ### Prerequisites
40 | 
41 | * `python` >= 3.7
42 | * [`pytorch`](https://pytorch.org) >= 1.7
43 | * `numpy`, `pandas`, `matplotlib`
44 | * `jupyter`, `easydict`
45 | * `rdkit`
46 |   * `rdkit` is not installed automatically when this package is installed, so it must be installed separately.
47 |   * `rdkit` install manual: [https://www.rdkit.org/docs/Install.html](https://www.rdkit.org/docs/Install.html)
48 | 
49 | ##### How to install `rdkit` (installation via `conda` is recommended)
50 | ```bash
51 | $ conda install -c conda-forge rdkit=2021.03.1
52 | ```
53 | 
54 | 
55 | ## Installation
56 | 
57 | ### Creating a virtual environment
58 | 
59 | We recommend running this package in an [`anaconda`](https://anaconda.org/) environment.
60 | First, create a virtual environment with `conda`.
61 | ```bash
62 | $ conda create --name laiddmg python=3.7
63 | $ conda activate laiddmg
64 | ```
65 | 
66 | Download this package via `git clone`.
67 | Then install it with `pip install .`.
68 | ```bash
69 | $ git clone https://github.com/ilguyi/LAIDD-molecular-generation.git
70 | $ cd LAIDD-molecular-generation
71 | $ pip install .
72 | $ conda install -c conda-forge rdkit=2021.03.1  # rdkit is not installed with the package and must be installed separately. 
73 | ```
74 | 
75 | 
76 | ## Quickstart
77 | 
78 | ### Jupyter notebook
79 | 
80 | Jupyter notebook files are provided so that you can simply
81 | run every step one by one.
82 | [`jupyter_char_rnn.ipynb`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/jupyter_char_rnn.ipynb) and
83 | [`jupyter_vae.ipynb`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/jupyter_vae.ipynb)
84 | run the CharRNN model and the ChemicalVAE model, respectively.
85 | The Jupyter files also require this repository to be installed before they can be used.
86 | 
87 | 
88 | ### Command execution
89 | 
90 | Cloning this github repository and installing the package creates two commands.
91 | * `laiddmg-train` command: takes the training data and trains each model (`CharRNN`, `ChemicalVAE`).
92 | * `laiddmg-generate` command: loads the final trained model and generates new molecules.
93 | 
94 | #### Training
95 | 
96 | * See the script files directly.
97 |   * [`train.char_rnn.sh`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/train.char_rnn.sh)
98 |   * [`train.vae.sh`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/train.vae.sh)
99 | 
100 | ```bash
101 | #!/bin/bash
102 | 
103 | laiddmg-train char_rnn --seed 219 \
104 |   --output_dir exp1 \
105 |   --dataset_path ../datasets \
106 |   --log_steps 10 \
107 |   --num_train_epochs 10 \
108 |   --train_batch_size 128 \
109 |   --[model_depend_arguments] \
110 |   ...
111 | ```
112 | 
113 | * `seed`: random seed number for reproducibility
114 | * `output_dir`: path of the output directory
115 | * `dataset_path`: path of the dataset
116 | * `log_steps`: logging interval (in steps)
117 | * `num_train_epochs`: maximum number of training epochs
118 | * `train_batch_size`: training batch size
119 | * `model_depend_arguments`: training arguments that differ per model
120 | 
121 | #### Generating
122 | 
123 | * See the generate script directly.
124 |   * [`generate.sh`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/generate.sh)
125 | 
126 | ```bash
127 | #!/bin/bash
128 | 
129 | laiddmg-generate char_rnn --seed 219 \
130 |   --checkpoint_dir outputs/char_rnn/exp1 \
131 |   --weights_name ckpt_100.pt \
132 |   --num_generation 10000 \
133 |   --batch_size_for_generation 256
134 | ```
135 | 
136 | * `seed`: random seed number for reproducibility
137 | * `checkpoint_dir`: path of the checkpoint directory that holds the weights file to load
138 | * `weights_name`: name of the weights file to load
139 | * `num_generation`: number of SMILES strings to generate
140 | * `batch_size_for_generation`: batch size used during generation
141 | 
142 | 
143 | ### Simple python code
144 | 
145 | #### Training
146 | 
147 | ```python
148 | >>> from laiddmg import Tokenizer, CharRNNConfig, CharRNNModel, TRAINER_MAPPING, get_rawdataset, get_dataset
149 | 
150 | >>> model_type = 'char_rnn'
151 | >>> tokenizer = Tokenizer()
152 | >>> config = CharRNNConfig()
153 | >>> model = CharRNNModel(config)
154 | 
155 | >>> inputs = tokenizer('c1ccccc1')
156 | >>> outputs = model(**inputs)
157 | 
158 | >>> train = get_rawdataset('train')
159 | >>> train_dataset = get_dataset(train, tokenizer)
160 | 
161 | >>> trainer = TRAINER_MAPPING[model_type]
162 | >>> t = trainer(model=model,
163 | ...             args=args,  # args: your training arguments (e.g. an argparse.Namespace or EasyDict)
164 | ...             train_dataset=train_dataset,
165 | ...             tokenizer=tokenizer)
166 | >>> t.train()
167 | ```
168 | 
169 | #### Generating
170 | 
171 | ```python
172 | >>> model.eval()
173 | >>> outputs = model.generate(tokenizer=tokenizer,
174 | ...                          max_length=128,
175 | ...                          num_return_sequences=256,
176 | ...                          skip_special_tokens=True)
177 | >>> print(outputs)
178 | ```
179 | 
180 | ## Dataset
181 | 
182 | This repository uses the
183 | [Molecular Sets (MOSES)](https://github.com/molecularsets/moses)
184 | dataset, a representative benchmark set
185 | in the molecular generation field; a minimal loading sketch follows below.
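Before going through the package's own `get_rawdataset` helper, the gzipped CSVs can be inspected directly with `pandas`. The snippet below is a minimal sketch, not part of the package; the `SMILES` column name is an assumption about the CSV layout, so check `train_df.columns` if it differs.

```python
import pandas as pd

# pandas decompresses .gz files transparently based on the file extension.
train_df = pd.read_csv('datasets/moses/train.csv.gz')

# Assumed column name for the SMILES strings.
smiles_list = train_df['SMILES'].tolist()
print(len(smiles_list), smiles_list[:3])
```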
186 | This dataset is based on the [ZINC](https://zinc.docking.org/) dataset,
187 | filtered according to [a few rules](https://github.com/molecularsets/moses#dataset).
188 | It originally consists of 1.9 million SMILES strings in total, but
189 | this repository randomly sampled the MOSES dataset so that the `train set`:`test set` sizes
190 | are 250,000 and 30,000, respectively.
191 | We recommend trying to train on the full MOSES dataset later.
192 | 
193 | 
194 | ## Model architectures
195 | 
196 | 1. **[CharRNN]**: Character-level recurrent neural networks / [Generating Focused Molecule Libraries for Drug Discovery with Recurrent Neural Networks](https://pubs.acs.org/doi/10.1021/acscentsci.7b00512), by Marwin H. S. Segler, Thierry Kogej, Christian Tyrchan, and Mark P. Waller.
197 | 1. **[ChemicalVAE]**: Variational autoencoders / [Automatic Chemical Design Using a Data-Driven Continuous Representation of Molecules](https://pubs.acs.org/doi/10.1021/acscentsci.7b00572), by Rafael Gómez-Bombarelli, Jennifer N. Wei, David Duvenaud, José Miguel Hernández-Lobato, Benjamín Sánchez-Lengeling, Dennis Sheberla, Jorge Aguilera-Iparraguirre, Timothy D. Hirzel, Ryan P. Adams, and Alán Aspuru-Guzik.
198 | 
199 | 
200 | ## Lecture notes
201 | 
202 | * [Lecture 1](https://www.dropbox.com/s/um8oukzoqlioff6/molecule%20generation%201.pdf?dl=0)
203 | * [Lecture 2](https://www.dropbox.com/s/okpnpjx2wmzioyo/molecule%20generation%202.pdf?dl=0)
204 | 
205 | 
206 | ## Author
207 | 
208 | * 이일구 (Il Gu Yi)
209 | * e-mail: ilgu.yi.work@gmail.com
210 | 
--------------------------------------------------------------------------------
/laiddmg/models/vae/modeling.py:
--------------------------------------------------------------------------------
 1 | from typing import List, Tuple, Union
 2 | 
 3 | import torch
 4 | import torch.nn as nn
 5 | import torch.nn.functional as F
 6 | import torch.nn.utils.rnn as rnn_utils
 7 | 
 8 | from .configuration import VAEConfig
 9 | from ...tokenization_utils import Tokenizer
10 | from ...modeling_utils import BaseModel
11 | from ... 
import logging_utils 12 | 13 | logger = logging_utils.get_logger(__name__) 14 | 15 | 16 | class Encoder(nn.Module): 17 | 18 | def __init__(self, config: VAEConfig, embeddings: nn.Module = None): 19 | super(Encoder, self).__init__() 20 | self.vocab_size = config.vocab_size 21 | self.embedding_dim = config.embedding_dim 22 | self.encoder_hidden_dim = config.encoder_hidden_dim 23 | self.encoder_num_layers = config.encoder_num_layers 24 | self.encoder_dropout = config.encoder_dropout 25 | self.latent_dim = config.latent_dim 26 | self.padding_value = config.padding_value 27 | 28 | if embeddings is not None: 29 | self.embeddings = embeddings 30 | else: 31 | self.embeddings = nn.Embedding(self.vocab_size, self.embedding_dim, 32 | padding_idx=self.padding_value) 33 | 34 | self.gru = nn.GRU(self.embedding_dim, 35 | self.encoder_hidden_dim, 36 | self.encoder_num_layers, 37 | batch_first=True, 38 | dropout=self.encoder_dropout if self.encoder_num_layers > 1 else 0) 39 | self.fc = nn.Linear(self.encoder_hidden_dim, self.latent_dim * 2) 40 | 41 | def forward( 42 | self, 43 | input_ids: torch.Tensor, # (batch_size, seq_len) 44 | lengths: torch.Tensor, # (batch_size,) 45 | **kwargs, 46 | ) -> Tuple[torch.Tensor]: 47 | x = self.embeddings(input_ids) # x: (batch_size, seq_len, embedding_dim) 48 | x = rnn_utils.pack_padded_sequence( 49 | x, 50 | lengths.cpu(), 51 | batch_first=True, 52 | enforce_sorted=False, 53 | ) 54 | _, hiddens = self.gru(x, None) # hiddens: (num_layers, batch_size, encoder_hidden_dim) 55 | 56 | hiddens = hiddens[-1] # hiddens: (batch_size, encoder_hidden_dim) for last layer 57 | 58 | z_mu, z_logvar = torch.split(self.fc(hiddens), self.latent_dim, dim=-1) 59 | # z_mu, z_logvar: (batch_size, latent_dim) 60 | 61 | return z_mu, z_logvar 62 | 63 | 64 | class Decoder(nn.Module): 65 | 66 | def __init__(self, config: VAEConfig, embeddings: nn.Module = None): 67 | super(Decoder, self).__init__() 68 | self.vocab_size = config.vocab_size 69 | self.embedding_dim = config.embedding_dim 70 | self.latent_dim = config.latent_dim 71 | self.decoder_hidden_dim = config.decoder_hidden_dim 72 | self.decoder_num_layers = config.decoder_num_layers 73 | self.decoder_dropout = config.decoder_dropout 74 | self.input_dim = self.embedding_dim + self.latent_dim 75 | self.output_dim = config.vocab_size 76 | self.padding_value = config.padding_value 77 | 78 | if embeddings is not None: 79 | self.embeddings = embeddings 80 | else: 81 | self.embeddings = nn.Embedding(self.vocab_size, self.embedding_dim, 82 | padding_idx=self.padding_value) 83 | 84 | self.gru = nn.GRU(self.input_dim, 85 | self.decoder_hidden_dim, 86 | self.decoder_num_layers, 87 | batch_first=True, 88 | dropout=self.decoder_dropout if self.decoder_num_layers > 1 else 0) 89 | self.z2hidden = nn.Linear(self.latent_dim, self.decoder_hidden_dim) 90 | self.fc = nn.Linear(self.decoder_hidden_dim, self.output_dim) 91 | 92 | def forward( 93 | self, 94 | input_ids: torch.Tensor, # (batch_size, seq_len) 95 | lengths: torch.Tensor, # (batch_size,) 96 | z: torch.Tensor, # (batch_size, latent_dim) 97 | **kwargs, 98 | ) -> Tuple[torch.Tensor]: 99 | x = self.embeddings(input_ids) # x: (batch_size, seq_len, embedding_dim) 100 | hiddens = self.z2hidden(z) # hiddens: (batch_size, decoder_hidden_dim) 101 | hiddens = hiddens.unsqueeze(0).repeat(self.decoder_num_layers, 1, 1) 102 | # hiddens: (num_layers, batch_size, decoder_hidden_dim) 103 | 104 | z_ = z.unsqueeze(1).repeat(1, x.shape[1], 1) # z: (batch_size, seq_len, latent_dim) 105 | x = torch.cat((x, z_), dim=-1) # x: 
(batch_size, seq_len, embedding_dim + latent_dim)
106 | 
107 |         x = rnn_utils.pack_padded_sequence(
108 |             x,
109 |             lengths.cpu(),
110 |             batch_first=True,
111 |             enforce_sorted=False
112 |         )
113 |         x, _ = self.gru(x, hiddens)
114 |         x, _ = rnn_utils.pad_packed_sequence(
115 |             x,
116 |             batch_first=True,
117 |         )  # x: (batch_size, seq_len, hidden_dim)
118 |         outputs = self.fc(x)  # outputs: (batch_size, seq_len, vocab_size)
119 | 
120 |         return outputs
121 | 
122 | 
123 | class VAEModel(BaseModel):
124 | 
125 |     def __init__(self, config: VAEConfig):
126 |         super(VAEModel, self).__init__()
127 |         self.config = config
128 |         self.vocab_size = config.vocab_size
129 |         self.embedding_dim = config.embedding_dim
130 |         self.latent_dim = config.latent_dim
131 |         self.padding_value = config.padding_value
132 | 
133 |         self.embeddings = nn.Embedding(self.vocab_size, self.embedding_dim,
134 |                                        padding_idx=self.padding_value)
135 | 
136 |         self.encoder = Encoder(self.config, self.embeddings)
137 |         self.decoder = Decoder(self.config, self.embeddings)
138 | 
139 |     def reparameterize(self, mean, logvar):
140 |         epsilon = torch.randn_like(mean)  # standard normal noise (randn, not rand) for the reparameterization trick
141 |         z = epsilon * torch.exp(logvar * .5) + mean  # mean, logvar, z: (batch_size, latent_dim)
142 | 
143 |         return z
144 | 
145 |     def forward(
146 |         self,
147 |         input_ids: torch.Tensor,  # (batch_size, seq_len)
148 |         lengths: torch.Tensor,  # (batch_size,)
149 |         **kwargs,
150 |     ) -> Tuple[torch.Tensor]:
151 |         z_mu, z_logvar = self.encoder(input_ids, lengths)
152 |         z = self.reparameterize(z_mu, z_logvar)  # z: (batch_size, latent_dim)
153 |         y = self.decoder(input_ids, lengths, z)  # y: (batch_size, seq_len, vocab_size)
154 | 
155 |         return y, z_mu, z_logvar
156 | 
157 |     def sample_gaussian_dist(self, batch_size: int):
158 |         return torch.randn(batch_size, self.latent_dim).to(self.device)
159 | 
160 |     @torch.no_grad()
161 |     def generate(
162 |         self,
163 |         tokenizer: Tokenizer = None,
164 |         max_length: int = 128,
165 |         num_return_sequences: int = 1,
166 |         skip_special_tokens: bool = False,
167 |         **kwargs,
168 |     ) -> Union[List[List[int]], List[List[str]]]:
169 | 
170 |         z = kwargs.pop('z', None)
171 |         z = z if z is not None else self.sample_gaussian_dist(num_return_sequences)
172 |         assert z.shape == (num_return_sequences, self.latent_dim)  # z: [batch_size, latent_dim]
173 |         z_ = z.unsqueeze(1)  # z_: [batch_size, 1, latent_dim]
174 | 
175 |         initial_inputs = torch.full((num_return_sequences, 1),
176 |                                     tokenizer.convert_token_to_id(tokenizer.start_token),
177 |                                     dtype=torch.long,
178 |                                     device=self.device)
179 |         generated_sequences = initial_inputs
180 |         input_ids = initial_inputs  # input_ids: [batch_size, 1]
181 | 
182 |         # z -> initial hiddens
183 |         hiddens = self.decoder.z2hidden(z)  # hiddens: [batch_size, hidden_dim]
184 |         hiddens = hiddens.unsqueeze(0).repeat(self.config.decoder_num_layers, 1, 1)
185 |         # hiddens: [decoder_num_layers, batch_size, hidden_dim]
186 | 
187 |         for i in range(max_length + 1):
188 |             x = self.decoder.embeddings(input_ids)  # x: [batch_size, 1, embedding_dim]
189 |             x = torch.cat((x, z_), dim=-1)  # x: [batch_size, 1, embedding_dim + latent_dim]
190 |             x, hiddens = self.decoder.gru(x, hiddens)  # x: [batch_size, 1, hidden_dim]
191 |             logits = self.decoder.fc(x)  # logits: [batch_size, 1, vocab_size]
192 |             next_token_logits = logits.squeeze(1)  # next_token_logits: [batch_size, vocab_size]
193 | 
194 |             probabilities = F.softmax(next_token_logits, dim=-1)  # probabilities: [batch_size, vocab_size]
195 |             next_tokens = torch.multinomial(probabilities, num_samples=1)
196 |             # next_tokens: [batch_size, 1]
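            # Sampling from the softmax distribution (rather than taking a greedy
            # argmax) keeps generation stochastic, so repeated calls yield diverse
            # SMILES. A deterministic variant could instead use:
            #   next_tokens = next_token_logits.argmax(dim=-1, keepdim=True)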
197 | 198 | input_ids = next_tokens 199 | generated_sequences = torch.cat([generated_sequences, next_tokens], dim=1) 200 | 201 | generated_sequences = self.postprocessing(generated_sequences, tokenizer) 202 | 203 | generated_SMILES = [] 204 | for sequence in generated_sequences: 205 | generated_SMILES.append(tokenizer.decode(sequence, skip_special_tokens)) 206 | 207 | return generated_SMILES 208 | -------------------------------------------------------------------------------- /laiddmg/jupyter_char_rnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Character RNN(CharRNN) 모델 설명 및 학습과 생성" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "ExecuteTime": { 15 | "end_time": "2021-08-30T04:41:02.273417Z", 16 | "start_time": "2021-08-30T04:41:01.576756Z" 17 | } 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "import os\n", 22 | "from easydict import EasyDict\n", 23 | "from typing import List, Tuple, Dict, Union\n", 24 | "\n", 25 | "from laiddmg import (\n", 26 | " CharRNNConfig,\n", 27 | " Tokenizer,\n", 28 | " CharRNNModel,\n", 29 | " get_rawdataset,\n", 30 | " get_dataset,\n", 31 | ")\n", 32 | "\n", 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "import torch.nn.functional as F\n", 36 | "import torch.nn.utils.rnn as rnn_utils\n", 37 | "import torch.optim as optim" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "ExecuteTime": { 44 | "end_time": "2021-08-18T01:17:37.048161Z", 45 | "start_time": "2021-08-18T01:17:37.045520Z" 46 | } 47 | }, 48 | "source": [ 49 | "## configuration, tokenizer, model 생성\n", 50 | "\n", 51 | "* `CharRNNConfig` class:\n", 52 | " * 모델을 구성하기 위해 필요한 정보(`hidden_dim`, `num_layers` 등)들이 담긴 class입니다.\n", 53 | " * 자세한 코드는 [`laiddmg/models/char_rnn/configuration.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/models/char_rnn/configuration.py)에 나와 있습니다.\n", 54 | "* `Tokenizer` class:\n", 55 | " * `str`으로 된 SMILES 데이터를 미리 정의해둔 `vocab_dict`에 맞춰 token data(`int`)로 바꿔주는 역할을 합니다.\n", 56 | " * 자세한 코드는 [`laiddmg/tokenization_utils.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/tokenization_utils.py)에 나와 있습니다.\n", 57 | "* `CharRNNModel` class:\n", 58 | " * 실제 모델을 만들어주는 클래스입니다.\n", 59 | " * `PyTorch`에서 제공하는 표준적인 방법으로 클래스를 구성하였습니다. tutorial은 [https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html) 여기서 확인할 수 있습니다.\n", 60 | " * 이 모델은 Marwin H. S. Segler, et. al., [Generating Focused Molecule Libraries for Drug Discovery with Recurrent Neural Networks](https://pubs.acs.org/doi/10.1021/acscentsci.7b00512)을 바탕으로 작성하였습니다.\n", 61 | " * 자세한 코드는 [`laiddmg/models/char_rnn/modeling.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/models/char_rnn/modeling.py)에 나와 있습니다." 
62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": { 68 | "ExecuteTime": { 69 | "end_time": "2021-08-30T04:41:06.536384Z", 70 | "start_time": "2021-08-30T04:41:06.437120Z" 71 | } 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "model_type = 'char_rnn'\n", 76 | "config = CharRNNConfig()\n", 77 | "tokenizer = Tokenizer()\n", 78 | "model = CharRNNModel(config)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "#### Print model configuration" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "ExecuteTime": { 93 | "end_time": "2021-08-29T06:39:51.464267Z", 94 | "start_time": "2021-08-29T06:39:51.458388Z" 95 | } 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "for k, v in config.__dict__.items():\n", 100 | " print(f'{k}: {v}')" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": { 106 | "ExecuteTime": { 107 | "end_time": "2021-08-18T01:20:33.876741Z", 108 | "start_time": "2021-08-18T01:20:33.874348Z" 109 | } 110 | }, 111 | "source": [ 112 | "#### How to use tokenizer" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": { 119 | "ExecuteTime": { 120 | "end_time": "2021-08-29T06:39:52.724103Z", 121 | "start_time": "2021-08-29T06:39:52.709519Z" 122 | } 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "tokenizer.vocab" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": { 133 | "ExecuteTime": { 134 | "end_time": "2021-08-29T06:39:52.868540Z", 135 | "start_time": "2021-08-29T06:39:52.863541Z" 136 | } 137 | }, 138 | "outputs": [], 139 | "source": [ 140 | "smiles = 'c1ccccc1' # 벤젠\n", 141 | "tokenizer(smiles)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": { 147 | "ExecuteTime": { 148 | "end_time": "2021-08-18T01:22:07.242192Z", 149 | "start_time": "2021-08-18T01:22:07.239650Z" 150 | } 151 | }, 152 | "source": [ 153 | "#### Print model's informations" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": { 160 | "ExecuteTime": { 161 | "end_time": "2021-08-29T06:39:53.813828Z", 162 | "start_time": "2021-08-29T06:39:53.809025Z" 163 | } 164 | }, 165 | "outputs": [], 166 | "source": [ 167 | "model" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": { 174 | "ExecuteTime": { 175 | "end_time": "2021-08-29T06:39:54.319575Z", 176 | "start_time": "2021-08-29T06:39:54.314478Z" 177 | } 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "print(f'model type: {config.model_type}')\n", 182 | "print(f'model device: {model.device}')\n", 183 | "print(f'model dtype: {model.dtype}')\n", 184 | "print(f'number of training parameters: {model.num_parameters()}')" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "### RNN Model class\n", 192 | "\n", 193 | "* 자세한 코드는 [`laiddmg/models/char_rnn/modeling.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/models/char_rnn/modeling.py)에 나와 있습니다." 
194 |    ]
195 |   },
196 |   {
197 |    "cell_type": "code",
198 |    "execution_count": null,
199 |    "metadata": {
200 |     "ExecuteTime": {
201 |      "end_time": "2021-08-29T06:39:58.361731Z",
202 |      "start_time": "2021-08-29T06:39:58.350272Z"
203 |     }
204 |    },
205 |    "outputs": [],
206 |    "source": [
207 |     "class _CharRNNModel(nn.Module):\n",
208 |     "\n",
209 |     "    def __init__(self, config: CharRNNConfig):\n",
210 |     "        super(_CharRNNModel, self).__init__()\n",
211 |     "        self.config = config\n",
212 |     "        self.vocab_size = config.vocab_size\n",
213 |     "        self.embedding_dim = config.embedding_dim\n",
214 |     "        self.hidden_dim = config.hidden_dim\n",
215 |     "        self.num_layers = config.num_layers\n",
216 |     "        self.dropout = config.dropout\n",
217 |     "        self.padding_value = config.padding_value\n",
218 |     "        self.output_dim = self.vocab_size\n",
219 |     "\n",
220 |     "        self.embeddings = nn.Embedding(self.vocab_size, self.embedding_dim,\n",
221 |     "                                       padding_idx=self.padding_value)\n",
222 |     "        self.lstm = nn.LSTM(self.embedding_dim, self.hidden_dim,\n",
223 |     "                            self.num_layers,\n",
224 |     "                            batch_first=True,\n",
225 |     "                            dropout=self.dropout)\n",
226 |     "        self.fc = nn.Linear(self.hidden_dim, self.output_dim)\n",
227 |     "\n",
228 |     "    def forward(\n",
229 |     "        self,\n",
230 |     "        input_ids: torch.Tensor,  # (batch_size, seq_len)\n",
231 |     "        lengths: torch.Tensor,  # (batch_size,)\n",
232 |     "        hiddens: Tuple[torch.Tensor] = None,  # (num_layers, batch_size, hidden_dim)\n",
233 |     "        **kwargs,\n",
234 |     "    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:\n",
235 |     "        x = self.embeddings(input_ids)  # x: (batch_size, seq_len, embedding_dim)\n",
236 |     "        x = rnn_utils.pack_padded_sequence(\n",
237 |     "            x,\n",
238 |     "            lengths.cpu(),\n",
239 |     "            batch_first=True,\n",
240 |     "            enforce_sorted=False,\n",
241 |     "        )\n",
242 |     "        x, hiddens = self.lstm(x, hiddens)\n",
243 |     "        # hiddens: (h, c); (num_layers, batch_size, hidden_dim), respectively\n",
244 |     "        x, _ = rnn_utils.pad_packed_sequence(\n",
245 |     "            x,\n",
246 |     "            batch_first=True,\n",
247 |     "        )  # x: (batch_size, seq_len, hidden_dim)\n",
248 |     "        outputs = self.fc(x)  # outputs: (batch_size, seq_len, vocab_size)\n",
249 |     "\n",
250 |     "        return outputs, hiddens"
251 |    ]
252 |   },
253 |   {
254 |    "cell_type": "markdown",
255 |    "metadata": {},
256 |    "source": [
257 |     "### `pack_padded_sequence` 설명을 위한 token data\n",
258 |     "\n",
259 |     "다음과 같은 token data (출처: https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983) 가 있다고 생각해봅시다.\n",
260 |     "총 5개의 데이터가 있고 각 문장(1개의 데이터)의 길이(한 문장의 token 갯수)는 다음과 같습니다.\n",
261 |     "`lengths = [6, 5, 2, 4, 1]`.\n",
262 |     "시퀀스 길이가 서로 다르기 때문에 가장 긴 길이에 맞춰 `padding`을 해줍니다.\n",
263 |     "![token_data](https://user-images.githubusercontent.com/11681225/129828808-e1e35cf2-1730-4e9d-b616-4426c11be1aa.png)\n",
264 |     "\n",
265 |     "다음과 같은 `vocab_dict`에 따라 `input_ids` tensor를 만들어줍니다.\n",
266 |     "```python\n",
267 |     "vocab_dict = {\n",
268 |     "    'I': 1, 'Mom': 2, 'No': 3, 'This': 4, 'Yes': 5, 'cooking': 6, 'is': 7, 'love': 8,\n",
269 |     "    's': 9, 'shit': 10, 'the': 11, 'too': 12, 'way': 13, 'you': 14, '!': 15, '`': 16,\n",
270 |     "}\n",
271 |     "```\n",
272 |     "우리가 사용할 `input_ids` tensor는 아래와 같습니다."
273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": { 279 | "ExecuteTime": { 280 | "end_time": "2021-08-29T06:40:00.254258Z", 281 | "start_time": "2021-08-29T06:40:00.249232Z" 282 | } 283 | }, 284 | "outputs": [], 285 | "source": [ 286 | "input_ids = torch.LongTensor([\n", 287 | " [ 1, 8, 2, 16, 9, 6],\n", 288 | " [ 1, 8, 14, 12, 15, 0],\n", 289 | " [ 3, 13, 0, 0, 0, 0],\n", 290 | " [ 4, 7, 11, 10, 0, 0],\n", 291 | " [ 5, 0, 0, 0, 0, 0]\n", 292 | "]) # input_ids: (batch_size, seq_len)\n", 293 | "lengths = [6, 5, 2, 4, 1]" 294 | ] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "metadata": {}, 299 | "source": [ 300 | "우리가 사용할 embedding matrix를 눈으로 확인하기 쉽게 만들어줍니다." 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "metadata": { 307 | "ExecuteTime": { 308 | "end_time": "2021-08-29T06:40:01.119206Z", 309 | "start_time": "2021-08-29T06:40:01.114980Z" 310 | } 311 | }, 312 | "outputs": [], 313 | "source": [ 314 | "vocab_size = 16 + 1 # `+ 1`: 1을 더해주는 이유는 padding index(0)를 추가하기 때문입니다.\n", 315 | "embedding_dim = 1\n", 316 | "embeddings = nn.Embedding.from_pretrained(torch.arange(vocab_size, dtype=torch.float).unsqueeze(1))" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "metadata": { 323 | "ExecuteTime": { 324 | "end_time": "2021-08-29T06:40:01.438971Z", 325 | "start_time": "2021-08-29T06:40:01.432542Z" 326 | } 327 | }, 328 | "outputs": [], 329 | "source": [ 330 | "embeddings.weight" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": { 336 | "ExecuteTime": { 337 | "end_time": "2021-08-18T05:48:14.629660Z", 338 | "start_time": "2021-08-18T05:48:14.626555Z" 339 | } 340 | }, 341 | "source": [ 342 | "### `pack_padded_sequence` 적용\n", 343 | "\n", 344 | "* `pack_padded_sequence`의 `PyTorch` 예제는 다음 링크에 있습니다. [https://pytorch.org/tutorials/beginner/chatbot_tutorial.html#encoder](https://pytorch.org/tutorials/beginner/chatbot_tutorial.html#encoder)\n", 345 | "* 함수 설명은 다음 링크에 있습니다. 
[https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html)" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": null, 351 | "metadata": { 352 | "ExecuteTime": { 353 | "end_time": "2021-08-29T06:40:02.884245Z", 354 | "start_time": "2021-08-29T06:40:02.879749Z" 355 | } 356 | }, 357 | "outputs": [], 358 | "source": [ 359 | "x = embeddings(input_ids) # x: (batch_size, seq_len, embedding_dim)\n", 360 | "packed_x = rnn_utils.pack_padded_sequence(\n", 361 | " x,\n", 362 | " lengths,\n", 363 | " batch_first=True,\n", 364 | " enforce_sorted=False,\n", 365 | ")" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "metadata": { 372 | "ExecuteTime": { 373 | "end_time": "2021-08-29T06:40:03.013097Z", 374 | "start_time": "2021-08-29T06:40:03.007976Z" 375 | } 376 | }, 377 | "outputs": [], 378 | "source": [ 379 | "x" 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": { 385 | "ExecuteTime": { 386 | "end_time": "2021-08-18T04:04:39.636584Z", 387 | "start_time": "2021-08-18T04:04:39.631237Z" 388 | } 389 | }, 390 | "source": [ 391 | "참고\n", 392 | "```\n", 393 | "input_ids = torch.LongTensor([\n", 394 | " [ 1, 8, 2, 16, 9, 6],\n", 395 | " [ 1, 8, 14, 12, 15, 0],\n", 396 | " [ 3, 13, 0, 0, 0, 0],\n", 397 | " [ 4, 7, 11, 10, 0, 0],\n", 398 | " [ 5, 0, 0, 0, 0, 0]\n", 399 | "])\n", 400 | "```" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": null, 406 | "metadata": { 407 | "ExecuteTime": { 408 | "end_time": "2021-08-29T06:40:03.661611Z", 409 | "start_time": "2021-08-29T06:40:03.655145Z" 410 | } 411 | }, 412 | "outputs": [], 413 | "source": [ 414 | "packed_x" 415 | ] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "metadata": { 420 | "ExecuteTime": { 421 | "end_time": "2021-08-18T04:06:24.051291Z", 422 | "start_time": "2021-08-18T04:06:24.045586Z" 423 | } 424 | }, 425 | "source": [ 426 | "참고: `pack_padded_sequence`를 수행하면 시퀀스 길이에 따라 정렬을 하고 정렬된 데이터를 `pack`을 해준다.\n", 427 | "```python\n", 428 | "input_ids = torch.LongTensor([\n", 429 | " [ 1, 8, 2, 16, 9, 6],\n", 430 | " [ 1, 8, 14, 12, 15, 0],\n", 431 | " [ 4, 7, 11, 10, 0, 0], # 시퀀스 길이 순서에 따라 4번째 행이 3번째 행으로 올라감\n", 432 | " [ 3, 13, 0, 0, 0, 0], # 시퀀스 길이 순서에 따라 3번째 행이 4번째 행으로 내려감\n", 433 | " [ 5, 0, 0, 0, 0, 0]\n", 434 | "])\n", 435 | "ᅟbatch_sizes = [5, 4, 3, 3, 2, 1]\n", 436 | "```" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": { 442 | "ExecuteTime": { 443 | "end_time": "2021-08-18T04:15:03.522340Z", 444 | "start_time": "2021-08-18T04:15:03.387400Z" 445 | } 446 | }, 447 | "source": [ 448 | "### `pack_padded_sequence` 이후의 모습\n", 449 | "\n", 450 | "![packed_sequence](https://user-images.githubusercontent.com/11681225/129835933-852b6add-2acc-493c-bfdd-7693d6cfe737.png)\n", 451 | "(출처: https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983)" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "metadata": {}, 457 | "source": [ 458 | "### 간단한 RNN 모델 생성" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": null, 464 | "metadata": { 465 | "ExecuteTime": { 466 | "end_time": "2021-08-29T06:40:05.041117Z", 467 | "start_time": "2021-08-29T06:40:05.037420Z" 468 | } 469 | }, 470 | "outputs": [], 471 | "source": [ 472 | "rnn = nn.RNN(embedding_dim, 2,\n", 473 | " batch_first=True)" 474 | ] 475 | }, 476 | { 477 | "cell_type": "markdown", 478 | 
"metadata": { 479 | "ExecuteTime": { 480 | "end_time": "2021-08-18T05:49:35.167399Z", 481 | "start_time": "2021-08-18T05:49:35.164803Z" 482 | } 483 | }, 484 | "source": [ 485 | "### `packed_x`와 `x` (pack 하지 않은 데이터)의 rnn 결과물 비교" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": null, 491 | "metadata": { 492 | "ExecuteTime": { 493 | "end_time": "2021-08-29T06:40:05.602914Z", 494 | "start_time": "2021-08-29T06:40:05.598453Z" 495 | } 496 | }, 497 | "outputs": [], 498 | "source": [ 499 | "rnn_x, hiddens = rnn(packed_x)\n", 500 | "rnn_x1, hiddens1 = rnn(x)\n", 501 | "# hiddens: (h, c); (num_layers, batch_size, hidden_dim), respectively" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": null, 507 | "metadata": { 508 | "ExecuteTime": { 509 | "end_time": "2021-08-29T06:40:05.766216Z", 510 | "start_time": "2021-08-29T06:40:05.761125Z" 511 | } 512 | }, 513 | "outputs": [], 514 | "source": [ 515 | "hiddens1" 516 | ] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "execution_count": null, 521 | "metadata": { 522 | "ExecuteTime": { 523 | "end_time": "2021-08-29T06:40:06.019543Z", 524 | "start_time": "2021-08-29T06:40:06.012997Z" 525 | } 526 | }, 527 | "outputs": [], 528 | "source": [ 529 | "rnn_x1" 530 | ] 531 | }, 532 | { 533 | "cell_type": "markdown", 534 | "metadata": { 535 | "ExecuteTime": { 536 | "end_time": "2021-08-18T04:31:47.402389Z", 537 | "start_time": "2021-08-18T04:31:47.399421Z" 538 | } 539 | }, 540 | "source": [ 541 | "#### `packed`된 데이터를 다시 원래 모습으로 바꿔주기 위해 `pad_packed_sequence`를 사용한다\n", 542 | "\n", 543 | "* `pad_packed_sequence`의 자세한 함수 설명은 다음 링크에서 확인할 수 있습니다. [https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_packed_sequence.html](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_packed_sequence.html)" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "metadata": { 550 | "ExecuteTime": { 551 | "end_time": "2021-08-29T06:40:06.360646Z", 552 | "start_time": "2021-08-29T06:40:06.354426Z" 553 | } 554 | }, 555 | "outputs": [], 556 | "source": [ 557 | "rnn_x" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": null, 563 | "metadata": { 564 | "ExecuteTime": { 565 | "end_time": "2021-08-29T06:40:06.865300Z", 566 | "start_time": "2021-08-29T06:40:06.854743Z" 567 | } 568 | }, 569 | "outputs": [], 570 | "source": [ 571 | "output_x, _ = rnn_utils.pad_packed_sequence(\n", 572 | " rnn_x,\n", 573 | " batch_first=True,\n", 574 | ") # x: (batch_size, seq_len, hidden_dim)" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": null, 580 | "metadata": { 581 | "ExecuteTime": { 582 | "end_time": "2021-08-29T06:40:06.962093Z", 583 | "start_time": "2021-08-29T06:40:06.954969Z" 584 | } 585 | }, 586 | "outputs": [], 587 | "source": [ 588 | "output_x" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": null, 594 | "metadata": { 595 | "ExecuteTime": { 596 | "end_time": "2021-08-29T06:40:07.058288Z", 597 | "start_time": "2021-08-29T06:40:07.052898Z" 598 | } 599 | }, 600 | "outputs": [], 601 | "source": [ 602 | "hiddens" 603 | ] 604 | }, 605 | { 606 | "cell_type": "markdown", 607 | "metadata": { 608 | "ExecuteTime": { 609 | "end_time": "2021-08-18T04:41:43.868416Z", 610 | "start_time": "2021-08-18T04:41:43.864883Z" 611 | } 612 | }, 613 | "source": [ 614 | "## Model 체크\n", 615 | "\n", 616 | "* model의 input으로는 `input_ids`와 `lengths`가 필요합니다. 
\n", 617 | "* `input_ids`는 `tokenizer`를 통해 SMILES character를 각각 token number로 바꾼 결과입니다.\n", 618 | "* `lengths`는 각 문장(각 SMILES 데이터)의 sequence length 정보입니다." 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": null, 624 | "metadata": { 625 | "ExecuteTime": { 626 | "end_time": "2021-08-29T06:40:07.218231Z", 627 | "start_time": "2021-08-29T06:40:07.212467Z" 628 | } 629 | }, 630 | "outputs": [], 631 | "source": [ 632 | "smiles = 'c1ccccc1'\n", 633 | "inputs = tokenizer(smiles)\n", 634 | "inputs" 635 | ] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "execution_count": null, 640 | "metadata": { 641 | "ExecuteTime": { 642 | "end_time": "2021-08-29T06:40:07.379145Z", 643 | "start_time": "2021-08-29T06:40:07.311016Z" 644 | } 645 | }, 646 | "outputs": [], 647 | "source": [ 648 | "# outputs, hiddens = model(**inputs)\n", 649 | "outputs, hiddens = model(input_ids=inputs['input_ids'],\n", 650 | " lengths=inputs['lengths'])" 651 | ] 652 | }, 653 | { 654 | "cell_type": "code", 655 | "execution_count": null, 656 | "metadata": { 657 | "ExecuteTime": { 658 | "end_time": "2021-08-29T06:40:07.470278Z", 659 | "start_time": "2021-08-29T06:40:07.466054Z" 660 | } 661 | }, 662 | "outputs": [], 663 | "source": [ 664 | "print(f'outputs shape: {outputs.shape}')\n", 665 | "print(f'hidden state shape: {hiddens[0].shape}')\n", 666 | "print(f'memory state shape: {hiddens[1].shape}')" 667 | ] 668 | }, 669 | { 670 | "cell_type": "markdown", 671 | "metadata": { 672 | "ExecuteTime": { 673 | "end_time": "2021-08-18T05:15:57.690026Z", 674 | "start_time": "2021-08-18T05:15:57.687056Z" 675 | } 676 | }, 677 | "source": [ 678 | "## Data 얻기\n", 679 | "\n", 680 | "* [Molecular Sets (MOSES): A benchmarking platform for molecular generation models](https://github.com/molecularsets/moses)에서 사용한 ZINC데이터를 random sampling을 통해 `train : test = 250000 : 30000`으로 나누었습니다.\n", 681 | "* 실제 데이터 파일 경로는 아래와 같습니다.\n", 682 | " * [`datasets/moses`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/datasets/moses)\n", 683 | "* `get_rawdataset`함수를 이용하여 얻은 데이터는 각 항목이 SMILES `str`데이터로 이루어진 `np.ndarray`입니다.\n", 684 | "* 이 rawdataset을 사용하기 편하게 `custom Dataset` class를 만들었습니다.\n", 685 | " * `custom Dataset`을 만드는 간단한 예제는 [PyTorch tutorial](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files)에 있습니다.\n", 686 | " * 자세한 코드는 [`laiddmg/ᅟdatasets.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/datasets.py)에 나와 있습니다." 
687 | ] 688 | }, 689 | { 690 | "cell_type": "code", 691 | "execution_count": null, 692 | "metadata": { 693 | "ExecuteTime": { 694 | "end_time": "2021-08-29T06:40:07.741410Z", 695 | "start_time": "2021-08-29T06:40:07.557656Z" 696 | } 697 | }, 698 | "outputs": [], 699 | "source": [ 700 | "train = get_rawdataset('train')" 701 | ] 702 | }, 703 | { 704 | "cell_type": "code", 705 | "execution_count": null, 706 | "metadata": { 707 | "ExecuteTime": { 708 | "end_time": "2021-08-29T06:40:07.747506Z", 709 | "start_time": "2021-08-29T06:40:07.743394Z" 710 | } 711 | }, 712 | "outputs": [], 713 | "source": [ 714 | "train" 715 | ] 716 | }, 717 | { 718 | "cell_type": "code", 719 | "execution_count": null, 720 | "metadata": { 721 | "ExecuteTime": { 722 | "end_time": "2021-08-29T06:40:07.783581Z", 723 | "start_time": "2021-08-29T06:40:07.779802Z" 724 | } 725 | }, 726 | "outputs": [], 727 | "source": [ 728 | "print(f'number of training dataset: {len(train)}')\n", 729 | "print(f'raw data type: {type(train[0])}')" 730 | ] 731 | }, 732 | { 733 | "cell_type": "markdown", 734 | "metadata": {}, 735 | "source": [ 736 | "#### `model`에 sample data 적용해보기" 737 | ] 738 | }, 739 | { 740 | "cell_type": "code", 741 | "execution_count": null, 742 | "metadata": { 743 | "ExecuteTime": { 744 | "end_time": "2021-08-29T06:40:07.935152Z", 745 | "start_time": "2021-08-29T06:40:07.928041Z" 746 | } 747 | }, 748 | "outputs": [], 749 | "source": [ 750 | "sampled_data = train[:4]\n", 751 | "inputs = tokenizer(sampled_data)\n", 752 | "inputs" 753 | ] 754 | }, 755 | { 756 | "cell_type": "markdown", 757 | "metadata": {}, 758 | "source": [ 759 | "#### `tokenizer`를 이용해 token data를 character로 바꾸기" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": null, 765 | "metadata": { 766 | "ExecuteTime": { 767 | "end_time": "2021-08-29T06:40:08.054085Z", 768 | "start_time": "2021-08-29T06:40:08.048887Z" 769 | } 770 | }, 771 | "outputs": [], 772 | "source": [ 773 | "tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True)" 774 | ] 775 | }, 776 | { 777 | "cell_type": "code", 778 | "execution_count": null, 779 | "metadata": { 780 | "ExecuteTime": { 781 | "end_time": "2021-08-29T06:40:08.183875Z", 782 | "start_time": "2021-08-29T06:40:08.179254Z" 783 | } 784 | }, 785 | "outputs": [], 786 | "source": [ 787 | "sampled_data" 788 | ] 789 | }, 790 | { 791 | "cell_type": "markdown", 792 | "metadata": {}, 793 | "source": [ 794 | "#### `inputs`를 `model`의 입력값으로 넣기" 795 | ] 796 | }, 797 | { 798 | "cell_type": "code", 799 | "execution_count": null, 800 | "metadata": { 801 | "ExecuteTime": { 802 | "end_time": "2021-08-29T06:40:09.939305Z", 803 | "start_time": "2021-08-29T06:40:08.324270Z" 804 | } 805 | }, 806 | "outputs": [], 807 | "source": [ 808 | "outputs, hiddens = model(**inputs)" 809 | ] 810 | }, 811 | { 812 | "cell_type": "code", 813 | "execution_count": null, 814 | "metadata": { 815 | "ExecuteTime": { 816 | "end_time": "2021-08-29T06:40:09.947569Z", 817 | "start_time": "2021-08-29T06:40:09.942923Z" 818 | } 819 | }, 820 | "outputs": [], 821 | "source": [ 822 | "print(f'outputs shape: {outputs.shape}')\n", 823 | "print(f'hidden state shape: {hiddens[0].shape}')\n", 824 | "print(f'memory state shape: {hiddens[1].shape}')" 825 | ] 826 | }, 827 | { 828 | "cell_type": "markdown", 829 | "metadata": { 830 | "ExecuteTime": { 831 | "end_time": "2021-08-18T05:43:51.928351Z", 832 | "start_time": "2021-08-18T05:43:51.924835Z" 833 | } 834 | }, 835 | "source": [ 836 | "### PyTorch `Dataset`, `DataLoader` 얻기" 837 | ] 838 | }, 839 | { 840 | 
"cell_type": "code", 841 | "execution_count": null, 842 | "metadata": { 843 | "ExecuteTime": { 844 | "end_time": "2021-08-29T06:40:09.952712Z", 845 | "start_time": "2021-08-29T06:40:09.949895Z" 846 | } 847 | }, 848 | "outputs": [], 849 | "source": [ 850 | "from torch.utils.data import Dataset, DataLoader" 851 | ] 852 | }, 853 | { 854 | "cell_type": "code", 855 | "execution_count": null, 856 | "metadata": { 857 | "ExecuteTime": { 858 | "end_time": "2021-08-29T06:40:09.957123Z", 859 | "start_time": "2021-08-29T06:40:09.954667Z" 860 | } 861 | }, 862 | "outputs": [], 863 | "source": [ 864 | "train_dataset = get_dataset(train, tokenizer)" 865 | ] 866 | }, 867 | { 868 | "cell_type": "code", 869 | "execution_count": null, 870 | "metadata": { 871 | "ExecuteTime": { 872 | "end_time": "2021-08-29T06:40:09.964348Z", 873 | "start_time": "2021-08-29T06:40:09.958710Z" 874 | } 875 | }, 876 | "outputs": [], 877 | "source": [ 878 | "train_dataset[1000]" 879 | ] 880 | }, 881 | { 882 | "cell_type": "markdown", 883 | "metadata": { 884 | "ExecuteTime": { 885 | "end_time": "2021-08-18T06:53:56.747987Z", 886 | "start_time": "2021-08-18T06:53:56.743885Z" 887 | } 888 | }, 889 | "source": [ 890 | "#### `input_id`와 `target`의 관계\n", 891 | "\n", 892 | "RNN을 이용한 생성모델(generative model)은 language model의 학습방법을 이용한다.\n", 893 | "Language model은 간단하게 이야기하면 다음 단어를 예측하는 모델이다.\n", 894 | "다음 단어를 예측한다는 뜻은 RNN 그림을 보면 쉽게 이해할 수 있다.\n", 895 | "\n", 896 | "![RNN-input-target](https://user-images.githubusercontent.com/11681225/129859647-af31934a-0eea-4ad8-9a85-2d3c2a75f517.jpeg)\n", 897 | "\n", 898 | "위와 같이 input data의 token이 하나씩 이동한 것이 target이 되는 것이다." 899 | ] 900 | }, 901 | { 902 | "cell_type": "code", 903 | "execution_count": null, 904 | "metadata": { 905 | "ExecuteTime": { 906 | "end_time": "2021-08-29T06:40:09.969404Z", 907 | "start_time": "2021-08-29T06:40:09.966058Z" 908 | } 909 | }, 910 | "outputs": [], 911 | "source": [ 912 | "def _pad_sequence(data: List[torch.Tensor],\n", 913 | " padding_value: int = 0) -> torch.Tensor:\n", 914 | " return rnn_utils.pad_sequence(data,\n", 915 | " batch_first=True,\n", 916 | " padding_value=padding_value)" 917 | ] 918 | }, 919 | { 920 | "cell_type": "code", 921 | "execution_count": null, 922 | "metadata": { 923 | "ExecuteTime": { 924 | "end_time": "2021-08-29T06:40:09.977292Z", 925 | "start_time": "2021-08-29T06:40:09.971061Z" 926 | } 927 | }, 928 | "outputs": [], 929 | "source": [ 930 | "def _collate_fn(batch: List[Dict[str, Union[torch.Tensor, str, int]]],\n", 931 | " **kwargs) -> Dict[str, Union[torch.Tensor, List[str], List[int]]]:\n", 932 | "\n", 933 | " indexes = [item['index'] for item in batch]\n", 934 | " smiles = [item['smiles'] for item in batch]\n", 935 | " input_ids = [item['input_id'] for item in batch]\n", 936 | " targets = [item['target'] for item in batch]\n", 937 | " lengths = [item['length'] for item in batch]\n", 938 | "\n", 939 | " padding_value = tokenizer.padding_value\n", 940 | " input_ids = _pad_sequence(input_ids, padding_value)\n", 941 | " targets = _pad_sequence(targets, padding_value)\n", 942 | " lengths = torch.LongTensor(lengths)\n", 943 | "\n", 944 | " return {'input_ids': input_ids,\n", 945 | " 'targets': targets,\n", 946 | " 'lengths': lengths,\n", 947 | " 'smiles': smiles,\n", 948 | " 'indexes': indexes}" 949 | ] 950 | }, 951 | { 952 | "cell_type": "code", 953 | "execution_count": null, 954 | "metadata": { 955 | "ExecuteTime": { 956 | "end_time": "2021-08-29T06:40:09.981352Z", 957 | "start_time": "2021-08-29T06:40:09.978997Z" 958 | } 959 | }, 960 | "outputs": 
[], 961 | "source": [ 962 | "train_dataloader = DataLoader(train_dataset,\n", 963 | " batch_size=4,\n", 964 | " shuffle=True,\n", 965 | " collate_fn=_collate_fn,\n", 966 | " )" 967 | ] 968 | }, 969 | { 970 | "cell_type": "markdown", 971 | "metadata": {}, 972 | "source": [ 973 | "#### `pad_sequence` 작동 방식\n", 974 | "\n", 975 | "* 한 batch내의 sequence length가 다른 데이터들의 sequence length를 가장 긴 데이터를 기준으로 `padding_value`(일반적으로 0)를 채워넣어 길이를 맞춰줍니다.(`padding 한다`라고 부릅니다.)\n", 976 | "* 자세한 함수 설명은 [여기](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_sequence.html)에 있습니다." 977 | ] 978 | }, 979 | { 980 | "cell_type": "code", 981 | "execution_count": null, 982 | "metadata": { 983 | "ExecuteTime": { 984 | "end_time": "2021-08-29T06:40:09.987912Z", 985 | "start_time": "2021-08-29T06:40:09.982615Z" 986 | } 987 | }, 988 | "outputs": [], 989 | "source": [ 990 | "input_ids = [train_dataset[i]['input_id'] for i in range(4)]\n", 991 | "input_ids" 992 | ] 993 | }, 994 | { 995 | "cell_type": "code", 996 | "execution_count": null, 997 | "metadata": { 998 | "ExecuteTime": { 999 | "end_time": "2021-08-29T06:40:09.994488Z", 1000 | "start_time": "2021-08-29T06:40:09.989555Z" 1001 | } 1002 | }, 1003 | "outputs": [], 1004 | "source": [ 1005 | "rnn_utils.pad_sequence(input_ids, batch_first=True, padding_value=0)" 1006 | ] 1007 | }, 1008 | { 1009 | "cell_type": "markdown", 1010 | "metadata": {}, 1011 | "source": [ 1012 | "#### `train_dataloader` 확인" 1013 | ] 1014 | }, 1015 | { 1016 | "cell_type": "code", 1017 | "execution_count": null, 1018 | "metadata": { 1019 | "ExecuteTime": { 1020 | "end_time": "2021-08-29T06:40:12.257401Z", 1021 | "start_time": "2021-08-29T06:40:12.237292Z" 1022 | } 1023 | }, 1024 | "outputs": [], 1025 | "source": [ 1026 | "batch_data = next(iter(train_dataloader))" 1027 | ] 1028 | }, 1029 | { 1030 | "cell_type": "code", 1031 | "execution_count": null, 1032 | "metadata": { 1033 | "ExecuteTime": { 1034 | "end_time": "2021-08-29T06:40:12.399717Z", 1035 | "start_time": "2021-08-29T06:40:12.395824Z" 1036 | } 1037 | }, 1038 | "outputs": [], 1039 | "source": [ 1040 | "batch_data.keys()" 1041 | ] 1042 | }, 1043 | { 1044 | "cell_type": "code", 1045 | "execution_count": null, 1046 | "metadata": { 1047 | "ExecuteTime": { 1048 | "end_time": "2021-08-29T06:40:12.566581Z", 1049 | "start_time": "2021-08-29T06:40:12.561334Z" 1050 | } 1051 | }, 1052 | "outputs": [], 1053 | "source": [ 1054 | "batch_data['input_ids']" 1055 | ] 1056 | }, 1057 | { 1058 | "cell_type": "code", 1059 | "execution_count": null, 1060 | "metadata": { 1061 | "ExecuteTime": { 1062 | "end_time": "2021-08-29T06:40:12.708655Z", 1063 | "start_time": "2021-08-29T06:40:12.703540Z" 1064 | } 1065 | }, 1066 | "outputs": [], 1067 | "source": [ 1068 | "batch_data['targets']" 1069 | ] 1070 | }, 1071 | { 1072 | "cell_type": "code", 1073 | "execution_count": null, 1074 | "metadata": { 1075 | "ExecuteTime": { 1076 | "end_time": "2021-08-29T06:40:12.858635Z", 1077 | "start_time": "2021-08-29T06:40:12.854207Z" 1078 | } 1079 | }, 1080 | "outputs": [], 1081 | "source": [ 1082 | "batch_data['lengths']" 1083 | ] 1084 | }, 1085 | { 1086 | "cell_type": "code", 1087 | "execution_count": null, 1088 | "metadata": { 1089 | "ExecuteTime": { 1090 | "end_time": "2021-08-29T06:40:13.019871Z", 1091 | "start_time": "2021-08-29T06:40:13.015652Z" 1092 | } 1093 | }, 1094 | "outputs": [], 1095 | "source": [ 1096 | "batch_data['smiles']" 1097 | ] 1098 | }, 1099 | { 1100 | "cell_type": "code", 1101 | "execution_count": null, 1102 | "metadata": { 1103 | "ExecuteTime": { 
1104 | "end_time": "2021-08-29T06:40:13.203012Z", 1105 | "start_time": "2021-08-29T06:40:13.198749Z" 1106 | } 1107 | }, 1108 | "outputs": [], 1109 | "source": [ 1110 | "batch_data['indexes']" 1111 | ] 1112 | }, 1113 | { 1114 | "cell_type": "markdown", 1115 | "metadata": { 1116 | "ExecuteTime": { 1117 | "end_time": "2021-08-18T23:33:34.497465Z", 1118 | "start_time": "2021-08-18T23:33:34.494157Z" 1119 | } 1120 | }, 1121 | "source": [ 1122 | "### Train without `Trainer` class\n", 1123 | "\n", 1124 | "* 실제 사용할 수 있게 패키징한 코드에서는 `Trainer` class를 만들어 사용하기 편리하게 모듈화 시켰습니다.\n", 1125 | "* 하지만 해당 Jupyter notebook은 이해를 돕기위해 모듈화 되어 있는 코드를 풀어서 블록 단위로 나타내었습니다.\n", 1126 | "* `Trainer`에 관련된 자세한 코드는 아래 링크에 있습니다.\n", 1127 | " * [`laiddmg/trainer.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/trainer.py)\n", 1128 | " * [`laiddmg/models/char_rnn/char_rnn_trainer.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/models/char_rnn/char_rnn_trainer.py)" 1129 | ] 1130 | }, 1131 | { 1132 | "cell_type": "markdown", 1133 | "metadata": {}, 1134 | "source": [ 1135 | "#### loss function and optimizer 생성" 1136 | ] 1137 | }, 1138 | { 1139 | "cell_type": "code", 1140 | "execution_count": null, 1141 | "metadata": { 1142 | "ExecuteTime": { 1143 | "end_time": "2021-08-30T02:52:36.901836Z", 1144 | "start_time": "2021-08-30T02:52:36.897139Z" 1145 | } 1146 | }, 1147 | "outputs": [], 1148 | "source": [ 1149 | "training_args = EasyDict({\n", 1150 | " 'output_dir': 'outputs/char_rnn/jupyter1',\n", 1151 | " 'num_train_epochs': 10,\n", 1152 | " 'batch_size': 256,\n", 1153 | " 'lr': 1e-3,\n", 1154 | " 'step_size': 10,\n", 1155 | " 'gamma': 0.5, \n", 1156 | "})" 1157 | ] 1158 | }, 1159 | { 1160 | "cell_type": "code", 1161 | "execution_count": null, 1162 | "metadata": { 1163 | "ExecuteTime": { 1164 | "end_time": "2021-08-29T06:40:16.574383Z", 1165 | "start_time": "2021-08-29T06:40:16.570911Z" 1166 | } 1167 | }, 1168 | "outputs": [], 1169 | "source": [ 1170 | "train_dataloader = DataLoader(train_dataset,\n", 1171 | " batch_size=training_args.batch_size,\n", 1172 | " shuffle=True,\n", 1173 | " collate_fn=_collate_fn,\n", 1174 | " )" 1175 | ] 1176 | }, 1177 | { 1178 | "cell_type": "code", 1179 | "execution_count": null, 1180 | "metadata": { 1181 | "ExecuteTime": { 1182 | "end_time": "2021-08-29T06:40:17.169980Z", 1183 | "start_time": "2021-08-29T06:40:17.164923Z" 1184 | } 1185 | }, 1186 | "outputs": [], 1187 | "source": [ 1188 | "loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.padding_value)\n", 1189 | "optimizer = optim.Adam(model.parameters(), lr=training_args.lr)\n", 1190 | "lr_scheduler = optim.lr_scheduler.StepLR(optimizer, training_args.step_size, training_args.gamma)" 1191 | ] 1192 | }, 1193 | { 1194 | "cell_type": "markdown", 1195 | "metadata": {}, 1196 | "source": [ 1197 | "### Plot for `lr_scheduler`\n", 1198 | "\n", 1199 | "* 학습할 때 더 빠른 수렴속도와 더 나은 정확도를 위해 learning rate를 조절하면서 학습하는 방식을 `learning rate scheduling`이라고 부릅니다.\n", 1200 | "* PyTorch에는 다양한 scheduler들이 잘 정리되어 있습니다.\n", 1201 | " * [여기](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate)를 참조하면 다양한 scheduler들을 볼 수 있습니다.\n", 1202 | " * 이 튜토리얼에서 사용하는 `StepLR scheduler`는 다음 [링크](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html)에서 확인할 수 있습니다." 
1203 | ] 1204 | }, 1205 | { 1206 | "cell_type": "code", 1207 | "execution_count": null, 1208 | "metadata": { 1209 | "ExecuteTime": { 1210 | "end_time": "2021-08-29T06:40:19.970371Z", 1211 | "start_time": "2021-08-29T06:40:19.770519Z" 1212 | } 1213 | }, 1214 | "outputs": [], 1215 | "source": [ 1216 | "import matplotlib.pyplot as plt\n", 1217 | "%matplotlib inline" 1218 | ] 1219 | }, 1220 | { 1221 | "cell_type": "code", 1222 | "execution_count": null, 1223 | "metadata": { 1224 | "ExecuteTime": { 1225 | "end_time": "2021-08-29T06:40:20.430025Z", 1226 | "start_time": "2021-08-29T06:40:20.423403Z" 1227 | } 1228 | }, 1229 | "outputs": [], 1230 | "source": [ 1231 | "lr_history = []\n", 1232 | "for _ in range(50):\n", 1233 | " lr_history.append(lr_scheduler.get_last_lr()[0])\n", 1234 | " lr_scheduler.step()" 1235 | ] 1236 | }, 1237 | { 1238 | "cell_type": "code", 1239 | "execution_count": null, 1240 | "metadata": { 1241 | "ExecuteTime": { 1242 | "end_time": "2021-08-29T06:40:22.573150Z", 1243 | "start_time": "2021-08-29T06:40:22.481208Z" 1244 | } 1245 | }, 1246 | "outputs": [], 1247 | "source": [ 1248 | "plt.plot(lr_history)" 1249 | ] 1250 | }, 1251 | { 1252 | "cell_type": "markdown", 1253 | "metadata": {}, 1254 | "source": [ 1255 | "### Training" 1256 | ] 1257 | }, 1258 | { 1259 | "cell_type": "code", 1260 | "execution_count": null, 1261 | "metadata": { 1262 | "ExecuteTime": { 1263 | "end_time": "2021-08-29T06:40:25.848046Z", 1264 | "start_time": "2021-08-29T06:40:25.844330Z" 1265 | } 1266 | }, 1267 | "outputs": [], 1268 | "source": [ 1269 | "# 위에서 그림을 그리기 위해 `lr_scheduler`를 업데이트 했기때문에 다시 생성해줍니다.x\n", 1270 | "lr_scheduler = optim.lr_scheduler.StepLR(optimizer, training_args.step_size, training_args.gamma)" 1271 | ] 1272 | }, 1273 | { 1274 | "cell_type": "code", 1275 | "execution_count": null, 1276 | "metadata": { 1277 | "ExecuteTime": { 1278 | "end_time": "2021-08-30T04:41:35.037165Z", 1279 | "start_time": "2021-08-30T04:41:34.335087Z" 1280 | } 1281 | }, 1282 | "outputs": [], 1283 | "source": [ 1284 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 1285 | "print(device)" 1286 | ] 1287 | }, 1288 | { 1289 | "cell_type": "code", 1290 | "execution_count": null, 1291 | "metadata": { 1292 | "ExecuteTime": { 1293 | "end_time": "2021-08-29T06:40:34.109285Z", 1294 | "start_time": "2021-08-29T06:40:30.937187Z" 1295 | } 1296 | }, 1297 | "outputs": [], 1298 | "source": [ 1299 | "model = model.to(device)\n", 1300 | "print(model.device)" 1301 | ] 1302 | }, 1303 | { 1304 | "cell_type": "code", 1305 | "execution_count": null, 1306 | "metadata": { 1307 | "ExecuteTime": { 1308 | "end_time": "2021-08-29T06:40:34.115505Z", 1309 | "start_time": "2021-08-29T06:40:34.111637Z" 1310 | } 1311 | }, 1312 | "outputs": [], 1313 | "source": [ 1314 | "def save_model(epoch: int, global_step: int, model: nn.Module):\n", 1315 | " checkpoint_dir = os.path.join(training_args.output_dir)\n", 1316 | " os.makedirs(checkpoint_dir, exist_ok=True)\n", 1317 | " ckpt_name = f'ckpt_{epoch:03d}.pt'\n", 1318 | " ckpt_path = os.path.join(checkpoint_dir, ckpt_name)\n", 1319 | " \n", 1320 | " torch.save({'epoch': epoch,\n", 1321 | " 'global_step': global_step,\n", 1322 | " 'model_state_dict': model.state_dict()},\n", 1323 | " ckpt_path)\n", 1324 | " print(f'saved {model.config.model_type} model at epoch {epoch}.')" 1325 | ] 1326 | }, 1327 | { 1328 | "cell_type": "code", 1329 | "execution_count": null, 1330 | "metadata": { 1331 | "ExecuteTime": { 1332 | "end_time": "2021-08-29T06:48:43.350278Z", 1333 | "start_time": 
"2021-08-29T06:40:49.063570Z" 1334 | }, 1335 | "scrolled": true 1336 | }, 1337 | "outputs": [], 1338 | "source": [ 1339 | "model.train()\n", 1340 | "global_step = 0\n", 1341 | "\n", 1342 | "for epoch in range(1, training_args.num_train_epochs + 1):\n", 1343 | " print(f'\\nStart training: {epoch} Epoch\\n')\n", 1344 | " \n", 1345 | " for i, data in enumerate(train_dataloader, 1):\n", 1346 | " optimizer.zero_grad()\n", 1347 | " \n", 1348 | " data['input_ids'] = data['input_ids'].to(device)\n", 1349 | " data['targets'] = data['targets'].to(device)\n", 1350 | " outputs, _ = model(data['input_ids'], data['lengths'])\n", 1351 | " \n", 1352 | " loss = loss_fn(outputs.view(-1, outputs.shape[-1]),\n", 1353 | " data['targets'].view(-1))\n", 1354 | " \n", 1355 | " loss.backward()\n", 1356 | " optimizer.step()\n", 1357 | " global_step += 1\n", 1358 | " \n", 1359 | " if global_step % 100 == 0:\n", 1360 | " print(f'{epoch} Epochs | {i}/{len(train_dataloader)} | loss: {loss.item():.4g} | '\n", 1361 | " f'lr: {lr_scheduler.get_last_lr()[0]:.4g}')\n", 1362 | " \n", 1363 | " lr_scheduler.step()\n", 1364 | " \n", 1365 | " save_model(epoch, global_step, model)\n", 1366 | " \n", 1367 | "print('Training done!!')" 1368 | ] 1369 | }, 1370 | { 1371 | "cell_type": "markdown", 1372 | "metadata": {}, 1373 | "source": [ 1374 | "## Generate new SMILES\n", 1375 | "\n", 1376 | "* model을 학습한 후에는 학습된 모델을 `load`하여 SMILES를 생성할 준비를 합니다.\n", 1377 | "* `model.generate`함수를 이용하면 새로운 SMILES sequence를 만들수 있습니다.\n", 1378 | "* 여기서는 generation의 각 과정을 하나씩 설명합니다.\n", 1379 | "* 자세한 코드는 [`laiddmg/generate.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/generate.py)에 나와 있습니다." 1380 | ] 1381 | }, 1382 | { 1383 | "cell_type": "code", 1384 | "execution_count": null, 1385 | "metadata": { 1386 | "ExecuteTime": { 1387 | "end_time": "2021-08-30T02:53:52.548186Z", 1388 | "start_time": "2021-08-30T02:53:49.341338Z" 1389 | } 1390 | }, 1391 | "outputs": [], 1392 | "source": [ 1393 | "# checkpoint_dir = training_args.output_dir\n", 1394 | "# model = CharRNNModel.from_pretrained(config,\n", 1395 | "# os.path.join(f'{checkpoint_dir}',\n", 1396 | "# f'ckpt_{training_args.num_train_epochs:03d}.pt'))\n", 1397 | "# model = model.to(device)\n", 1398 | "# model.eval()" 1399 | ] 1400 | }, 1401 | { 1402 | "cell_type": "markdown", 1403 | "metadata": {}, 1404 | "source": [ 1405 | "* 본 수업에서는 시간관계상 미리 학습한 `best_model`을 다운 받아서 씁니다." 
1406 | ] 1407 | }, 1408 | { 1409 | "cell_type": "code", 1410 | "execution_count": null, 1411 | "metadata": { 1412 | "ExecuteTime": { 1413 | "end_time": "2021-08-30T04:47:40.590550Z", 1414 | "start_time": "2021-08-30T04:47:32.147704Z" 1415 | } 1416 | }, 1417 | "outputs": [], 1418 | "source": [ 1419 | "!wget 'https://www.dropbox.com/s/fqwpvx6nfh2ba1p/char_rnn_best.tar.gz?dl=0'\n", 1420 | "!tar xvzf char_rnn_best.tar.gz?dl=0\n", 1421 | "!rm -f char_rnn_best.tar.gz?dl=0\n", 1422 | "!mv best_model/ outputs/char_rnn/" 1423 | ] 1424 | }, 1425 | { 1426 | "cell_type": "code", 1427 | "execution_count": null, 1428 | "metadata": { 1429 | "ExecuteTime": { 1430 | "end_time": "2021-08-30T04:48:24.270682Z", 1431 | "start_time": "2021-08-30T04:48:21.252646Z" 1432 | } 1433 | }, 1434 | "outputs": [], 1435 | "source": [ 1436 | "model = CharRNNModel.from_pretrained(config,\n", 1437 | "                                     os.path.join('./outputs/char_rnn/best_model/best_model.pt'))\n", 1438 | "model = model.to(device)\n", 1439 | "model.eval()" 1440 | ] 1441 | }, 1442 | { 1443 | "cell_type": "code", 1444 | "execution_count": null, 1445 | "metadata": { 1446 | "ExecuteTime": { 1447 | "end_time": "2021-08-30T04:48:26.480258Z", 1448 | "start_time": "2021-08-30T04:48:26.477650Z" 1449 | } 1450 | }, 1451 | "outputs": [], 1452 | "source": [ 1453 | "batch_size_for_generate = 4" 1454 | ] 1455 | }, 1456 | { 1457 | "cell_type": "code", 1458 | "execution_count": null, 1459 | "metadata": { 1460 | "ExecuteTime": { 1461 | "end_time": "2021-08-30T04:48:26.765576Z", 1462 | "start_time": "2021-08-30T04:48:26.671926Z" 1463 | } 1464 | }, 1465 | "outputs": [], 1466 | "source": [ 1467 | "outputs = model.generate(tokenizer=tokenizer,\n", 1468 | "                         max_length=128,\n", 1469 | "                         num_return_sequences=batch_size_for_generate,\n", 1470 | "                         skip_special_tokens=True)" 1471 | ] 1472 | }, 1473 | { 1474 | "cell_type": "code", 1475 | "execution_count": null, 1476 | "metadata": { 1477 | "ExecuteTime": { 1478 | "end_time": "2021-08-30T04:48:27.425610Z", 1479 | "start_time": "2021-08-30T04:48:27.417480Z" 1480 | } 1481 | }, 1482 | "outputs": [], 1483 | "source": [ 1484 | "outputs" 1485 | ] 1486 | }, 1487 | { 1488 | "cell_type": "markdown", 1489 | "metadata": { 1490 | "ExecuteTime": { 1491 | "end_time": "2021-08-20T01:23:46.853695Z", 1492 | "start_time": "2021-08-20T01:23:46.850494Z" 1493 | } 1494 | }, 1495 | "source": [ 1496 | "### The generation process, step by step\n", 1497 | "\n", 1498 | "1. Put the start token (`tokenizer.start_token`) into the `input_ids` variable\n", 1499 | "2. Feed `input_ids` (after embedding) into the lstm module to get `outputs` and the `hidden state`\n", 1500 | "3. Pass `outputs` through the `Linear` layer to get `next_token_logits`\n", 1501 | "4. Apply `softmax` to `next_token_logits` to get a probability distribution\n", 1502 | "5. Sample from that probability distribution (using `torch.multinomial`)\n", 1503 | "6. The sampled values become `next_tokens`, which serves as the rnn input at the next step (`input_ids = next_tokens`)\n", 1504 | "7. Repeat steps 2 ~ 6 (a combined sketch of this loop appears after step 6 below)" 1505 | ] 1506 | }, 1507 | { 1508 | "cell_type": "markdown", 1509 | "metadata": { 1510 | "ExecuteTime": { 1511 | "end_time": "2021-08-20T02:15:33.183458Z", 1512 | "start_time": "2021-08-20T02:15:33.172551Z" 1513 | } 1514 | }, 1515 | "source": [ 1516 | "#### step 1. 
Put the start token (`tokenizer.start_token`) into the `input_ids` variable" 1517 | ] 1518 | }, 1519 | { 1520 | "cell_type": "code", 1521 | "execution_count": null, 1522 | "metadata": { 1523 | "ExecuteTime": { 1524 | "end_time": "2021-08-29T06:59:36.986082Z", 1525 | "start_time": "2021-08-29T06:59:36.972456Z" 1526 | } 1527 | }, 1528 | "outputs": [], 1529 | "source": [ 1530 | "initial_inputs = torch.full((batch_size_for_generate, 1),\n", 1531 | "                            tokenizer.convert_token_to_id(tokenizer.start_token),\n", 1532 | "                            dtype=torch.long,\n", 1533 | "                            device=model.device)\n", 1534 | "generated_sequences = initial_inputs\n", 1535 | "input_ids = initial_inputs\n", 1536 | "hiddens = model.reset_states(batch_size_for_generate)" 1537 | ] 1538 | }, 1539 | { 1540 | "cell_type": "code", 1541 | "execution_count": null, 1542 | "metadata": { 1543 | "ExecuteTime": { 1544 | "end_time": "2021-08-29T06:59:37.277541Z", 1545 | "start_time": "2021-08-29T06:59:37.272904Z" 1546 | } 1547 | }, 1548 | "outputs": [], 1549 | "source": [ 1550 | "input_ids" 1551 | ] 1552 | }, 1553 | { 1554 | "cell_type": "code", 1555 | "execution_count": null, 1556 | "metadata": { 1557 | "ExecuteTime": { 1558 | "end_time": "2021-08-29T06:59:37.287049Z", 1559 | "start_time": "2021-08-29T06:59:37.279207Z" 1560 | } 1561 | }, 1562 | "outputs": [], 1563 | "source": [ 1564 | "hiddens" 1565 | ] 1566 | }, 1567 | { 1568 | "cell_type": "markdown", 1569 | "metadata": { 1570 | "ExecuteTime": { 1571 | "end_time": "2021-08-20T02:15:33.225361Z", 1572 | "start_time": "2021-08-20T02:15:33.192776Z" 1573 | } 1574 | }, 1575 | "source": [ 1576 | "#### step 2. Feed `input_ids` (after embedding) into the lstm module to get `outputs` and the `hidden state`" 1577 | ] 1578 | }, 1579 | { 1580 | "cell_type": "code", 1581 | "execution_count": null, 1582 | "metadata": { 1583 | "ExecuteTime": { 1584 | "end_time": "2021-08-29T06:59:37.478962Z", 1585 | "start_time": "2021-08-29T06:59:37.476442Z" 1586 | } 1587 | }, 1588 | "outputs": [], 1589 | "source": [ 1590 | "x = model.embeddings(input_ids)\n", 1591 | "x, hiddens = model.lstm(x, hiddens)" 1592 | ] 1593 | }, 1594 | { 1595 | "cell_type": "code", 1596 | "execution_count": null, 1597 | "metadata": { 1598 | "ExecuteTime": { 1599 | "end_time": "2021-08-29T06:59:37.503401Z", 1600 | "start_time": "2021-08-29T06:59:37.499400Z" 1601 | } 1602 | }, 1603 | "outputs": [], 1604 | "source": [ 1605 | "x.shape" 1606 | ] 1607 | }, 1608 | { 1609 | "cell_type": "markdown", 1610 | "metadata": { 1611 | "ExecuteTime": { 1612 | "end_time": "2021-08-20T02:15:33.243819Z", 1613 | "start_time": "2021-08-20T02:15:33.230683Z" 1614 | } 1615 | }, 1616 | "source": [ 1617 | "#### step 3. 
Pass `outputs` through the `Linear` layer to get `next_token_logits`" 1618 | ] 1619 | }, 1620 | { 1621 | "cell_type": "code", 1622 | "execution_count": null, 1623 | "metadata": { 1624 | "ExecuteTime": { 1625 | "end_time": "2021-08-29T06:59:37.687381Z", 1626 | "start_time": "2021-08-29T06:59:37.684832Z" 1627 | } 1628 | }, 1629 | "outputs": [], 1630 | "source": [ 1631 | "logits = model.fc(x)\n", 1632 | "next_token_logits = logits.squeeze(1)" 1633 | ] 1634 | }, 1635 | { 1636 | "cell_type": "code", 1637 | "execution_count": null, 1638 | "metadata": { 1639 | "ExecuteTime": { 1640 | "end_time": "2021-08-29T06:59:37.792288Z", 1641 | "start_time": "2021-08-29T06:59:37.788169Z" 1642 | } 1643 | }, 1644 | "outputs": [], 1645 | "source": [ 1646 | "logits.shape" 1647 | ] 1648 | }, 1649 | { 1650 | "cell_type": "code", 1651 | "execution_count": null, 1652 | "metadata": { 1653 | "ExecuteTime": { 1654 | "end_time": "2021-08-29T06:59:37.808641Z", 1655 | "start_time": "2021-08-29T06:59:37.805009Z" 1656 | } 1657 | }, 1658 | "outputs": [], 1659 | "source": [ 1660 | "next_token_logits.shape" 1661 | ] 1662 | }, 1663 | { 1664 | "cell_type": "markdown", 1665 | "metadata": { 1666 | "ExecuteTime": { 1667 | "end_time": "2021-08-20T02:15:33.257006Z", 1668 | "start_time": "2021-08-20T02:15:33.254591Z" 1669 | } 1670 | }, 1671 | "source": [ 1672 | "#### step 4. Apply `softmax` to `next_token_logits` to get a probability distribution" 1673 | ] 1674 | }, 1675 | { 1676 | "cell_type": "code", 1677 | "execution_count": null, 1678 | "metadata": { 1679 | "ExecuteTime": { 1680 | "end_time": "2021-08-29T06:59:38.012787Z", 1681 | "start_time": "2021-08-29T06:59:38.009952Z" 1682 | } 1683 | }, 1684 | "outputs": [], 1685 | "source": [ 1686 | "probabilities = F.softmax(next_token_logits, dim=-1)" 1687 | ] 1688 | }, 1689 | { 1690 | "cell_type": "code", 1691 | "execution_count": null, 1692 | "metadata": { 1693 | "ExecuteTime": { 1694 | "end_time": "2021-08-29T06:59:38.096026Z", 1695 | "start_time": "2021-08-29T06:59:38.089875Z" 1696 | } 1697 | }, 1698 | "outputs": [], 1699 | "source": [ 1700 | "probabilities[0]" 1701 | ] 1702 | }, 1703 | { 1704 | "cell_type": "markdown", 1705 | "metadata": { 1706 | "ExecuteTime": { 1707 | "end_time": "2021-08-20T02:15:33.267899Z", 1708 | "start_time": "2021-08-20T02:15:33.263783Z" 1709 | } 1710 | }, 1711 | "source": [ 1712 | "#### step 5. 
Sample from that probability distribution (using [`torch.multinomial`](https://pytorch.org/docs/stable/generated/torch.multinomial.html?highlight=multinomial#torch.multinomial))" 1713 | ] 1714 | }, 1715 | { 1716 | "cell_type": "code", 1717 | "execution_count": null, 1718 | "metadata": { 1719 | "ExecuteTime": { 1720 | "end_time": "2021-08-29T06:59:38.299402Z", 1721 | "start_time": "2021-08-29T06:59:38.294912Z" 1722 | } 1723 | }, 1724 | "outputs": [], 1725 | "source": [ 1726 | "next_tokens = torch.multinomial(probabilities, num_samples=1)\n", 1727 | "next_tokens" 1728 | ] 1729 | }, 1730 | { 1731 | "cell_type": "markdown", 1732 | "metadata": { 1733 | "ExecuteTime": { 1734 | "end_time": "2021-08-20T01:49:41.633366Z", 1735 | "start_time": "2021-08-20T01:49:41.626309Z" 1736 | } 1737 | }, 1738 | "source": [ 1739 | "For reference, `tokenizer.vocab` (as a mapping):\n", 1740 | "\n", 1741 | "```python\n", 1742 | "{'#': 4, '(': 5, ')': 6, '-': 7,\n", 1743 | " '1': 8, '2': 9, '3': 10, '4': 11, '5': 12, '6': 13, '=': 14,\n", 1744 | " 'B': 15, 'C': 16, 'F': 17, 'H': 18, 'N': 19, 'O': 20, 'S': 21,\n", 1745 | " '[': 22, ']': 23, 'c': 24, 'l': 25, 'n': 26, 'o': 27, 'r': 28, 's': 29}\n", 1746 | "```" 1747 | ] 1748 | }, 1749 | { 1750 | "cell_type": "markdown", 1751 | "metadata": { 1752 | "ExecuteTime": { 1753 | "end_time": "2021-08-20T02:15:33.272783Z", 1754 | "start_time": "2021-08-20T02:15:33.269071Z" 1755 | } 1756 | }, 1757 | "source": [ 1758 | "#### step 6. The sampled values become `next_tokens`, which serves as the rnn input at the next step (`input_ids = next_tokens`)" 1759 | ] 1760 | }, 1761 | { 1762 | "cell_type": "code", 1763 | "execution_count": null, 1764 | "metadata": { 1765 | "ExecuteTime": { 1766 | "end_time": "2021-08-29T06:59:55.394155Z", 1767 | "start_time": "2021-08-29T06:59:55.387802Z" 1768 | } 1769 | }, 1770 | "outputs": [], 1771 | "source": [ 1772 | "input_ids = next_tokens\n", 1773 | "generated_sequences = torch.cat((generated_sequences, next_tokens), dim=1)\n", 1774 | "generated_sequences" 1775 | ] 1776 | }, 1777 | { 1778 | "cell_type": "markdown", 1779 | "metadata": { 1780 | "ExecuteTime": { 1781 | "end_time": "2021-08-20T02:04:21.601326Z", 1782 | "start_time": "2021-08-20T02:04:21.598478Z" 1783 | } 1784 | }, 1785 | "source": [ 1786 | "#### The steps above are packaged into the `generate` function; an equivalent sketch of the full loop follows." 
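An editor-added sketch of the sampling loop that `model.generate` wraps, reusing the tensors from steps 1–6 above (simplified: the real implementation also handles the end token, per-sequence termination, and `skip_special_tokens`).

```python
# Combines steps 2-6 into one loop; assumes step 1 (initial_inputs) has run.
max_length = 128
for _ in range(max_length - 1):
    x = model.embeddings(input_ids)                       # step 2: embed
    x, hiddens = model.lstm(x, hiddens)                   # step 2: lstm
    next_token_logits = model.fc(x).squeeze(1)            # step 3: linear
    probabilities = F.softmax(next_token_logits, dim=-1)  # step 4: softmax
    next_tokens = torch.multinomial(probabilities, num_samples=1)  # step 5
    input_ids = next_tokens                               # step 6: feed back
    generated_sequences = torch.cat((generated_sequences, next_tokens), dim=1)
```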
1787 | ] 1788 | }, 1789 | { 1790 | "cell_type": "code", 1791 | "execution_count": null, 1792 | "metadata": { 1793 | "ExecuteTime": { 1794 | "end_time": "2021-08-30T04:48:41.399699Z", 1795 | "start_time": "2021-08-30T04:48:41.174496Z" 1796 | } 1797 | }, 1798 | "outputs": [], 1799 | "source": [ 1800 | "outputs = model.generate(tokenizer=tokenizer,\n", 1801 | "                         max_length=128,\n", 1802 | "                         # num_return_sequences=batch_size_for_generate,\n", 1803 | "                         num_return_sequences=128,\n", 1804 | "                         skip_special_tokens=True)" 1805 | ] 1806 | }, 1807 | { 1808 | "cell_type": "code", 1809 | "execution_count": null, 1810 | "metadata": { 1811 | "ExecuteTime": { 1812 | "end_time": "2021-08-30T04:48:41.462953Z", 1813 | "start_time": "2021-08-30T04:48:41.401356Z" 1814 | } 1815 | }, 1816 | "outputs": [], 1817 | "source": [ 1818 | "import rdkit\n", 1819 | "from rdkit import Chem\n", 1820 | "from rdkit.Chem.Draw import IPythonConsole" 1821 | ] 1822 | }, 1823 | { 1824 | "cell_type": "code", 1825 | "execution_count": null, 1826 | "metadata": { 1827 | "ExecuteTime": { 1828 | "end_time": "2021-08-30T04:48:41.532597Z", 1829 | "start_time": "2021-08-30T04:48:41.516620Z" 1830 | }, 1831 | "scrolled": true 1832 | }, 1833 | "outputs": [], 1834 | "source": [ 1835 | "mols = []\n", 1836 | "\n", 1837 | "for s in outputs:\n", 1838 | "    # Chem.MolFromSmiles returns None (with an RDKit warning) for invalid\n", 1839 | "    # SMILES, so filtering on None keeps exactly the valid molecules.\n", 1840 | "    mol = Chem.MolFromSmiles(s)\n", 1841 | "    if mol is not None:\n", 1842 | "        mols.append(mol)" 1843 | ] 1844 | }, 1845 | { 1846 | "cell_type": "code", 1847 | "execution_count": null, 1848 | "metadata": { 1849 | "ExecuteTime": { 1850 | "end_time": "2021-08-30T04:48:41.690935Z", 1851 | "start_time": "2021-08-30T04:48:41.687267Z" 1852 | } 1853 | }, 1854 | "outputs": [], 1855 | "source": [ 1856 | "len(mols)" 1857 | ] 1858 | }, 1859 | { 1860 | "cell_type": "code", 1861 | "execution_count": null, 1862 | "metadata": { 1863 | "ExecuteTime": { 1864 | "end_time": "2021-08-30T04:48:41.860569Z", 1865 | "start_time": "2021-08-30T04:48:41.852326Z" 1866 | } 1867 | }, 1868 | "outputs": [], 1869 | "source": [ 1870 | "mols[0]" 1871 | ] 1872 | }, 1873 | { 1874 | "cell_type": "code", 1875 | "execution_count": null, 1876 | "metadata": { 1877 | "ExecuteTime": { 1878 | "end_time": "2021-08-30T04:48:42.004226Z", 1879 | "start_time": "2021-08-30T04:48:41.996902Z" 1880 | } 1881 | }, 1882 | "outputs": [], 1883 | "source": [ 1884 | "mols[1]" 1885 | ] 1886 | }, 1887 | { 1888 | "cell_type": "code", 1889 | "execution_count": null, 1890 | "metadata": { 1891 | "ExecuteTime": { 1892 | "end_time": "2021-08-30T04:48:42.205290Z", 1893 | "start_time": "2021-08-30T04:48:42.197579Z" 1894 | } 1895 | }, 1896 | "outputs": [], 1897 | "source": [ 1898 | "mols[2]" 1899 | ] 1900 | }, 1901 | { 1902 | "cell_type": "code", 1903 | "execution_count": null, 1904 | "metadata": {}, 1905 | "outputs": [], 1906 | "source": [] 1907 | } 1908 | ], 1909 | "metadata": { 1910 | "kernelspec": { 1911 | "display_name": "Python [conda env:laiddmg] *", 1912 | "language": "python", 1913 | "name": "conda-env-laiddmg-py" 1914 | }, 1915 | "language_info": { 1916 | "codemirror_mode": { 1917 | "name": "ipython", 1918 | "version": 3 1919 | }, 1920 | "file_extension": ".py", 1921 | "mimetype": "text/x-python", 1922 | "name": "python", 1923 | "nbconvert_exporter": "python", 1924 | "pygments_lexer": "ipython3", 1925 | "version": "3.7.11" 1926 | } 1927 | }, 1928 | "nbformat": 4, 1929 | "nbformat_minor": 4 1930 | } 1931 | -------------------------------------------------------------------------------- /laiddmg/jupyter_vae.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Variational AutoEncoder (VAE): model walkthrough, training, and generation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "ExecuteTime": { 15 | "end_time": "2021-08-30T11:35:36.440899Z", 16 | "start_time": "2021-08-30T11:35:35.762140Z" 17 | } 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "import os\n", 22 | "from easydict import EasyDict\n", 23 | "from typing import List, Tuple, Dict, Union\n", 24 | "\n", 25 | "from laiddmg import (\n", 26 | "    VAEConfig,\n", 27 | "    Tokenizer,\n", 28 | "    VAEModel,\n", 29 | "    get_rawdataset,\n", 30 | "    get_dataset,\n", 31 | ")\n", 32 | "\n", 33 | "import torch\n", 34 | "import torch.nn as nn\n", 35 | "import torch.nn.functional as F\n", 36 | "import torch.nn.utils.rnn as rnn_utils\n", 37 | "import torch.optim as optim" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "ExecuteTime": { 44 | "end_time": "2021-08-18T01:17:37.048161Z", 45 | "start_time": "2021-08-18T01:17:37.045520Z" 46 | } 47 | }, 48 | "source": [ 49 | "## Create the configuration, tokenizer, and model\n", 50 | "\n", 51 | "* `VAEConfig` class:\n", 52 | "  * Holds the information needed to build the model (`hidden_dim`, `num_layers`, etc.); an illustrative sketch of such fields follows below.\n", 53 | "  * The full code is in [`laiddmg/models/vae/configuration.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/models/vae/configuration.py).\n", 54 | "* `Tokenizer` class:\n", 55 | "  * Converts SMILES data given as `str` into token data (`int`) according to the predefined `vocab_dict`.\n", 56 | "  * The full code is in [`laiddmg/tokenization_utils.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/tokenization_utils.py).\n", 57 | "* `VAEModel` class:\n", 58 | "  * The class that builds the actual model.\n", 59 | "  * It is structured in the standard `PyTorch` way; the tutorial is at [https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html).\n", 60 | "  * The model is based on Rafael Gómez-Bombarelli et al., [Automatic Chemical Design Using a Data-Driven Continuous Representation of Molecules](https://pubs.acs.org/doi/10.1021/acscentsci.7b00572).\n", 61 | "  * The full code is in [`laiddmg/models/vae/modeling.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/models/vae/modeling.py)." 
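A rough editor-added sketch of the kind of fields `VAEConfig` carries, inferred from the `Encoder`/`Decoder` code later in this notebook; the field names match that code, but the default values here are illustrative guesses, not the package's actual defaults.

```python
from dataclasses import dataclass

@dataclass
class VAEConfigSketch:
    # hypothetical stand-in for laiddmg's VAEConfig; see configuration.py for the real one
    vocab_size: int = 30
    embedding_dim: int = 32
    latent_dim: int = 128          # matches the z = torch.randn(1, 128) example below
    encoder_hidden_dim: int = 256
    encoder_num_layers: int = 1
    encoder_dropout: float = 0.5
    decoder_hidden_dim: int = 512
    decoder_num_layers: int = 3
    decoder_dropout: float = 0.0
    padding_value: int = 0
```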
62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": { 68 | "ExecuteTime": { 69 | "end_time": "2021-08-30T11:35:37.744102Z", 70 | "start_time": "2021-08-30T11:35:37.682233Z" 71 | } 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "model_type = 'vae'\n", 76 | "config = VAEConfig()\n", 77 | "tokenizer = Tokenizer()\n", 78 | "model = VAEModel(config)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "#### Print the model configuration" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "ExecuteTime": { 93 | "end_time": "2021-08-30T11:35:39.779688Z", 94 | "start_time": "2021-08-30T11:35:39.772649Z" 95 | } 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "for k, v in config.__dict__.items():\n", 100 | "    print(f'{k}: {v}')" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": { 106 | "ExecuteTime": { 107 | "end_time": "2021-08-18T01:20:33.876741Z", 108 | "start_time": "2021-08-18T01:20:33.874348Z" 109 | } 110 | }, 111 | "source": [ 112 | "#### How to use the tokenizer" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": { 119 | "ExecuteTime": { 120 | "end_time": "2021-08-30T11:35:41.577068Z", 121 | "start_time": "2021-08-30T11:35:41.562805Z" 122 | } 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "tokenizer.vocab" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": { 133 | "ExecuteTime": { 134 | "end_time": "2021-08-30T11:35:41.792639Z", 135 | "start_time": "2021-08-30T11:35:41.786934Z" 136 | } 137 | }, 138 | "outputs": [], 139 | "source": [ 140 | "smiles = 'c1ccccc1'  # benzene\n", 141 | "tokenizer(smiles)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": { 147 | "ExecuteTime": { 148 | "end_time": "2021-08-18T01:22:07.242192Z", 149 | "start_time": "2021-08-18T01:22:07.239650Z" 150 | } 151 | }, 152 | "source": [ 153 | "#### Print the model's information" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": { 160 | "ExecuteTime": { 161 | "end_time": "2021-08-30T11:35:44.287604Z", 162 | "start_time": "2021-08-30T11:35:44.282092Z" 163 | } 164 | }, 165 | "outputs": [], 166 | "source": [ 167 | "model" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": { 174 | "ExecuteTime": { 175 | "end_time": "2021-08-30T11:35:44.617501Z", 176 | "start_time": "2021-08-30T11:35:44.612641Z" 177 | } 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "print(f'model type: {config.model_type}')\n", 182 | "print(f'model device: {model.device}')\n", 183 | "print(f'model dtype: {model.dtype}')\n", 184 | "print(f'number of training parameters: {model.num_parameters()}')" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": { 190 | "ExecuteTime": { 191 | "end_time": "2021-08-18T04:41:43.868416Z", 192 | "start_time": "2021-08-18T04:41:43.864883Z" 193 | } 194 | }, 195 | "source": [ 196 | "## Model check\n", 197 | "\n", 198 | "* The model needs `input_ids` and `lengths` as its inputs.\n", 199 | "* `input_ids` is the result of converting each SMILES character into its token number with the `tokenizer`.\n", 200 | "* `lengths` holds the sequence length of each sentence (each SMILES string); a toy illustration of how the two fit together follows below." 
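An editor-added toy example (made-up ids) of how padded `input_ids` pair with `lengths`; this is what lets `pack_padded_sequence` skip the padded tail inside the GRU later on.

```python
import torch

# Two tokenized SMILES of different lengths, padded to a common width with 0.
input_ids = torch.tensor([[1, 16, 8, 16, 16, 2],
                          [1, 16, 2, 0,  0,  0]])  # 0 = padding_value
lengths = torch.tensor([6, 3])  # true (unpadded) length of each row

# rnn_utils.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
# then guarantees the RNN never runs over the padded positions.
```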
201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": { 207 | "ExecuteTime": { 208 | "end_time": "2021-08-30T11:35:45.537171Z", 209 | "start_time": "2021-08-30T11:35:45.532126Z" 210 | } 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "smiles = 'c1ccccc1'\n", 215 | "inputs = tokenizer(smiles)\n", 216 | "inputs" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": { 223 | "ExecuteTime": { 224 | "end_time": "2021-08-30T11:35:46.149366Z", 225 | "start_time": "2021-08-30T11:35:46.101540Z" 226 | } 227 | }, 228 | "outputs": [], 229 | "source": [ 230 | "# outputs, z_mu, z_logvar = model(**inputs)\n", 231 | "outputs, z_mu, z_logvar = model(input_ids=inputs['input_ids'],\n", 232 | "                                lengths=inputs['lengths'])" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": { 239 | "ExecuteTime": { 240 | "end_time": "2021-08-30T11:35:46.371932Z", 241 | "start_time": "2021-08-30T11:35:46.367348Z" 242 | } 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "print(f'outputs shape: {outputs.shape}')\n", 247 | "print(f'latent vector mean shape: {z_mu.shape}')\n", 248 | "print(f'latent vector log variance shape: {z_logvar.shape}')" 249 | ] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "metadata": {}, 254 | "source": [ 255 | "### VAE architecture\n", 256 | "\n", 257 | "*(VAE architecture diagram 'vae-gaussian'; the original image reference was lost in extraction)*" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": {}, 263 | "source": [ 264 | "### A closer look at the Encoder\n", 265 | "\n", 266 | "* The full code is in [`laiddmg/models/vae/modeling.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/models/vae/modeling.py)." 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "metadata": { 273 | "ExecuteTime": { 274 | "end_time": "2021-08-30T11:36:33.529206Z", 275 | "start_time": "2021-08-30T11:36:33.522604Z" 276 | } 277 | }, 278 | "outputs": [], 279 | "source": [ 280 | "class Encoder(nn.Module):\n", 281 | "\n", 282 | "    def __init__(self, config: VAEConfig, embeddings: nn.Module = None):\n", 283 | "        super(Encoder, self).__init__()\n", 284 | "        self.vocab_size = config.vocab_size\n", 285 | "        self.embedding_dim = config.embedding_dim\n", 286 | "        self.encoder_hidden_dim = config.encoder_hidden_dim\n", 287 | "        self.encoder_num_layers = config.encoder_num_layers\n", 288 | "        self.encoder_dropout = config.encoder_dropout\n", 289 | "        self.latent_dim = config.latent_dim\n", 290 | "        self.padding_value = config.padding_value\n", 291 | "\n", 292 | "        if embeddings is not None:\n", 293 | "            self.embeddings = embeddings\n", 294 | "        else:\n", 295 | "            self.embeddings = nn.Embedding(self.vocab_size, self.embedding_dim,\n", 296 | "                                           padding_idx=self.padding_value)\n", 297 | "\n", 298 | "        self.gru = nn.GRU(self.embedding_dim,\n", 299 | "                          self.encoder_hidden_dim,\n", 300 | "                          self.encoder_num_layers,\n", 301 | "                          batch_first=True,\n", 302 | "                          dropout=self.encoder_dropout if self.encoder_num_layers > 1 else 0)\n", 303 | "        self.fc = nn.Linear(self.encoder_hidden_dim, self.latent_dim * 2)\n", 304 | "\n", 305 | "    def forward(\n", 306 | "        self,\n", 307 | "        input_ids: torch.Tensor,  # (batch_size, seq_len)\n", 308 | "        lengths: torch.Tensor,  # (batch_size,)\n", 309 | "        **kwargs,\n", 310 | "    ) -> Tuple[torch.Tensor]:\n", 311 | "        x = self.embeddings(input_ids)  # x: (batch_size, seq_len, embedding_dim)\n", 312 | "        x = rnn_utils.pack_padded_sequence(\n", 313 | "            x,\n", 314 | "            lengths.cpu(),\n", 315 | "            
batch_first=True,\n", 316 | "            enforce_sorted=False,\n", 317 | "        )\n", 318 | "        _, hiddens = self.gru(x, None)  # hiddens: (num_layers, batch_size, encoder_hidden_dim)\n", 319 | "\n", 320 | "        hiddens = hiddens[-1]  # hiddens: (batch_size, encoder_hidden_dim) for last layer\n", 321 | "\n", 322 | "        z_mu, z_logvar = torch.split(self.fc(hiddens), self.latent_dim, dim=-1)\n", 323 | "        # z_mu, z_logvar: (batch_size, latent_dim)\n", 324 | "\n", 325 | "        return z_mu, z_logvar" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "metadata": { 332 | "ExecuteTime": { 333 | "end_time": "2021-08-30T11:36:33.948313Z", 334 | "start_time": "2021-08-30T11:36:33.932352Z" 335 | } 336 | }, 337 | "outputs": [], 338 | "source": [ 339 | "encoder = Encoder(config)\n", 340 | "z_mu, z_logvar = encoder(**inputs)" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "metadata": { 347 | "ExecuteTime": { 348 | "end_time": "2021-08-30T11:35:49.095925Z", 349 | "start_time": "2021-08-30T11:35:49.091569Z" 350 | } 351 | }, 352 | "outputs": [], 353 | "source": [ 354 | "print(f'latent vector mean shape: {z_mu.shape}')\n", 355 | "print(f'latent vector log variance shape: {z_logvar.shape}')" 356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": {}, 361 | "source": [ 362 | "### A closer look at the Decoder\n", 363 | "\n", 364 | "* The full code is in [`laiddmg/models/vae/modeling.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/models/vae/modeling.py)." 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "metadata": { 371 | "ExecuteTime": { 372 | "end_time": "2021-08-30T11:36:54.571398Z", 373 | "start_time": "2021-08-30T11:36:54.558123Z" 374 | } 375 | }, 376 | "outputs": [], 377 | "source": [ 378 | "class Decoder(nn.Module):\n", 379 | "\n", 380 | "    def __init__(self, config: VAEConfig, embeddings: nn.Module = None):\n", 381 | "        super(Decoder, self).__init__()\n", 382 | "        self.vocab_size = config.vocab_size\n", 383 | "        self.embedding_dim = config.embedding_dim\n", 384 | "        self.latent_dim = config.latent_dim\n", 385 | "        self.decoder_hidden_dim = config.decoder_hidden_dim\n", 386 | "        self.decoder_num_layers = config.decoder_num_layers\n", 387 | "        self.decoder_dropout = config.decoder_dropout\n", 388 | "        self.input_dim = self.embedding_dim + self.latent_dim\n", 389 | "        self.output_dim = config.vocab_size\n", 390 | "        self.padding_value = config.padding_value\n", 391 | "\n", 392 | "        if embeddings is not None:\n", 393 | "            self.embeddings = embeddings\n", 394 | "        else:\n", 395 | "            self.embeddings = nn.Embedding(self.vocab_size, self.embedding_dim,\n", 396 | "                                           padding_idx=self.padding_value)\n", 397 | "\n", 398 | "        self.gru = nn.GRU(self.input_dim,\n", 399 | "                          self.decoder_hidden_dim,\n", 400 | "                          self.decoder_num_layers,\n", 401 | "                          batch_first=True,\n", 402 | "                          dropout=self.decoder_dropout if self.decoder_num_layers > 1 else 0)\n", 403 | "        self.z2hidden = nn.Linear(self.latent_dim, self.decoder_hidden_dim)\n", 404 | "        self.fc = nn.Linear(self.decoder_hidden_dim, self.output_dim)\n", 405 | "\n", 406 | "    def forward(\n", 407 | "        self,\n", 408 | "        input_ids: torch.Tensor,  # (batch_size, seq_len)\n", 409 | "        lengths: torch.Tensor,  # (batch_size,)\n", 410 | "        z: torch.Tensor,  # (batch_size, latent_dim)\n", 411 | "        **kwargs,\n", 412 | "    ) -> Tuple[torch.Tensor]:\n", 413 | "        x = self.embeddings(input_ids)  # x: (batch_size, seq_len, embedding_dim)\n", 414 | "        hiddens = self.z2hidden(z)  # hiddens: (batch_size, 
decoder_hidden_dim)\n", 415 | "        hiddens = hiddens.unsqueeze(0).repeat(self.decoder_num_layers, 1, 1)\n", 416 | "        # hiddens: (num_layers, batch_size, decoder_hidden_dim)\n", 417 | "\n", 418 | "        z_ = z.unsqueeze(1).repeat(1, x.shape[1], 1)  # z: (batch_size, seq_len, latent_dim)\n", 419 | "        x = torch.cat((x, z_), dim=-1)  # x: (batch_size, seq_len, embedding_dim + latent_dim)\n", 420 | "\n", 421 | "        x = rnn_utils.pack_padded_sequence(\n", 422 | "            x,\n", 423 | "            lengths.cpu(),\n", 424 | "            batch_first=True,\n", 425 | "            enforce_sorted=False\n", 426 | "        )\n", 427 | "        x, _ = self.gru(x, hiddens)\n", 428 | "        x, _ = rnn_utils.pad_packed_sequence(\n", 429 | "            x,\n", 430 | "            batch_first=True,\n", 431 | "        )  # x: (batch_size, seq_len, hidden_dim)\n", 432 | "        outputs = self.fc(x)  # outputs: (batch_size, seq_len, vocab_size)\n", 433 | "\n", 434 | "        return outputs" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": null, 440 | "metadata": { 441 | "ExecuteTime": { 442 | "end_time": "2021-08-30T11:36:55.008243Z", 443 | "start_time": "2021-08-30T11:36:54.936948Z" 444 | } 445 | }, 446 | "outputs": [], 447 | "source": [ 448 | "decoder = Decoder(config)\n", 449 | "outputs = decoder(**inputs, z=torch.randn(1, 128))  # one SMILES, latent_dim = 128" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": null, 455 | "metadata": { 456 | "ExecuteTime": { 457 | "end_time": "2021-08-30T11:36:55.596778Z", 458 | "start_time": "2021-08-30T11:36:55.592063Z" 459 | } 460 | }, 461 | "outputs": [], 462 | "source": [ 463 | "print(f'decoder outputs shape: {outputs.shape}')\n", 464 | "print(f'input_ids shape: {inputs[\"input_ids\"].shape}')" 465 | ] 466 | }, 467 | { 468 | "cell_type": "markdown", 469 | "metadata": { 470 | "ExecuteTime": { 471 | "end_time": "2021-08-20T05:35:48.564754Z", 472 | "start_time": "2021-08-20T05:35:48.558789Z" 473 | } 474 | }, 475 | "source": [ 476 | "### VAE model class\n", 477 | "\n", 478 | "* The full code is in [`laiddmg/models/vae/modeling.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/models/vae/modeling.py).\n", 479 | "\n", 480 | "\n", 481 | "### The VAE model, step by step\n", 482 | "\n", 483 | "1. Feed the input data (`input_ids`) into the Encoder (`encoder`).\n", 484 | "2. Apply the reparametrization trick to `z_mu` and `z_logvar` to produce an actually sampled latent vector `z`.\n", 485 | "3. Feed the encoded latent vector `z`, together with the input data (the same input the `encoder` received), into the Decoder (`decoder`). A compact sketch of all three steps follows below." 
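An editor-added compact sketch of the three steps, using the `encoder` and `decoder` instances constructed above; it is equivalent in spirit to `_VAEModel.forward` shown next.

```python
# step 1: encode the inputs into posterior parameters
z_mu, z_logvar = encoder(inputs['input_ids'], inputs['lengths'])

# step 2: reparametrize -- z = mu + sigma * eps with eps ~ N(0, I)
eps = torch.randn_like(z_mu)
z = z_mu + torch.exp(0.5 * z_logvar) * eps

# step 3: decode, conditioning every timestep on z
y = decoder(inputs['input_ids'], inputs['lengths'], z)
```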
486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": null, 491 | "metadata": { 492 | "ExecuteTime": { 493 | "end_time": "2021-08-30T11:36:58.192456Z", 494 | "start_time": "2021-08-30T11:36:58.182662Z" 495 | } 496 | }, 497 | "outputs": [], 498 | "source": [ 499 | "class _VAEModel(nn.Module):\n", 500 | "\n", 501 | "    def __init__(self, config: VAEConfig):\n", 502 | "        super(_VAEModel, self).__init__()\n", 503 | "        self.config = config\n", 504 | "        self.vocab_size = config.vocab_size\n", 505 | "        self.embedding_dim = config.embedding_dim\n", 506 | "        self.latent_dim = config.latent_dim\n", 507 | "        self.padding_value = config.padding_value\n", 508 | "\n", 509 | "        self.embeddings = nn.Embedding(self.vocab_size, self.embedding_dim,\n", 510 | "                                       padding_idx=self.padding_value)\n", 511 | "\n", 512 | "        self.encoder = Encoder(self.config, self.embeddings)\n", 513 | "        self.decoder = Decoder(self.config, self.embeddings)\n", 514 | "\n", 515 | "    def reparameterize(self, mean, logvar):\n", 516 | "        epsilon = torch.randn_like(mean)  # epsilon ~ N(0, I); rand_like would sample uniformly\n", 517 | "        z = epsilon * torch.exp(logvar * .5) + mean  # mean, logvar, z: (batch_size, latent_dim)\n", 518 | "\n", 519 | "        return z\n", 520 | "\n", 521 | "    def forward(\n", 522 | "        self,\n", 523 | "        input_ids: torch.Tensor,  # (batch_size, seq_len)\n", 524 | "        lengths: torch.Tensor,  # (batch_size,)\n", 525 | "        **kwargs,\n", 526 | "    ) -> Tuple[torch.Tensor]:\n", 527 | "        z_mu, z_logvar = self.encoder(input_ids, lengths)\n", 528 | "        z = self.reparameterize(z_mu, z_logvar)  # z: (batch_size, latent_dim)\n", 529 | "        y = self.decoder(input_ids, lengths, z)  # y: (batch_size, seq_len, vocab_size)\n", 530 | "\n", 531 | "        return y, z_mu, z_logvar" 532 | ] 533 | }, 534 | { 535 | "cell_type": "markdown", 536 | "metadata": { 537 | "ExecuteTime": { 538 | "end_time": "2021-08-20T07:24:26.172053Z", 539 | "start_time": "2021-08-20T07:24:26.158807Z" 540 | } 541 | }, 542 | "source": [ 543 | "#### 1. Feed the input data (`input_ids`) into the Encoder (`encoder`)." 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "metadata": { 550 | "ExecuteTime": { 551 | "end_time": "2021-08-30T11:36:59.991359Z", 552 | "start_time": "2021-08-30T11:36:59.972081Z" 553 | } 554 | }, 555 | "outputs": [], 556 | "source": [ 557 | "z_mu, z_logvar = model.encoder(**inputs)" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": null, 563 | "metadata": { 564 | "ExecuteTime": { 565 | "end_time": "2021-08-30T11:37:00.215995Z", 566 | "start_time": "2021-08-30T11:37:00.212210Z" 567 | } 568 | }, 569 | "outputs": [], 570 | "source": [ 571 | "print(f'latent vector mean shape: {z_mu.shape}')\n", 572 | "print(f'latent vector log variance shape: {z_logvar.shape}')" 573 | ] 574 | }, 575 | { 576 | "cell_type": "markdown", 577 | "metadata": {}, 578 | "source": [ 579 | "#### 2. Apply the reparametrization trick to `z_mu` and `z_logvar` to produce an actually sampled latent vector `z`. (A short note on the math follows.)" 
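In math form (an editor-added note), the reparametrization trick rewrites sampling from the posterior as a deterministic function of the parameters plus external noise, so gradients can flow through the mean and log-variance:

```latex
z = \mu + \sigma \odot \epsilon, \qquad
\epsilon \sim \mathcal{N}(0, I), \qquad
\sigma = \exp\left(\tfrac{1}{2}\log\sigma^{2}\right)
```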
580 | ] 581 | }, 582 | { 583 | "cell_type": "code", 584 | "execution_count": null, 585 | "metadata": { 586 | "ExecuteTime": { 587 | "end_time": "2021-08-30T11:37:02.151208Z", 588 | "start_time": "2021-08-30T11:37:02.148256Z" 589 | } 590 | }, 591 | "outputs": [], 592 | "source": [ 593 | "epsilon = torch.randn_like(z_mu)  # epsilon ~ N(0, I)\n", 594 | "z = epsilon * torch.exp(z_logvar * .5) + z_mu" 595 | ] 596 | }, 597 | { 598 | "cell_type": "code", 599 | "execution_count": null, 600 | "metadata": { 601 | "ExecuteTime": { 602 | "end_time": "2021-08-30T11:37:02.355535Z", 603 | "start_time": "2021-08-30T11:37:02.351740Z" 604 | } 605 | }, 606 | "outputs": [], 607 | "source": [ 608 | "print(f'latent vector shape: {z.shape}')" 609 | ] 610 | }, 611 | { 612 | "cell_type": "markdown", 613 | "metadata": { 614 | "ExecuteTime": { 615 | "end_time": "2021-08-20T07:24:26.172053Z", 616 | "start_time": "2021-08-20T07:24:26.158807Z" 617 | } 618 | }, 619 | "source": [ 620 | "#### 3. Feed the encoded latent vector `z`, together with the input data (the same input the `encoder` received), into the decoder." 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "execution_count": null, 626 | "metadata": { 627 | "ExecuteTime": { 628 | "end_time": "2021-08-30T11:37:03.732705Z", 629 | "start_time": "2021-08-30T11:37:03.682298Z" 630 | } 631 | }, 632 | "outputs": [], 633 | "source": [ 634 | "outputs = model.decoder(**inputs, z=z)" 635 | ] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "execution_count": null, 640 | "metadata": { 641 | "ExecuteTime": { 642 | "end_time": "2021-08-30T11:37:04.125578Z", 643 | "start_time": "2021-08-30T11:37:04.122239Z" 644 | } 645 | }, 646 | "outputs": [], 647 | "source": [ 648 | "print(f'outputs shape: {outputs.shape}')" 649 | ] 650 | }, 651 | { 652 | "cell_type": "markdown", 653 | "metadata": { 654 | "ExecuteTime": { 655 | "end_time": "2021-08-18T05:15:57.690026Z", 656 | "start_time": "2021-08-18T05:15:57.687056Z" 657 | } 658 | }, 659 | "source": [ 660 | "## Getting the data\n", 661 | "\n", 662 | "* The ZINC data used in [Molecular Sets (MOSES): A benchmarking platform for molecular generation models](https://github.com/molecularsets/moses) was split by random sampling into `train : test = 250000 : 30000`.\n", 663 | "* The actual data file paths are:\n", 664 | "  * [`datasets/moses`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/datasets/moses)\n", 665 | "* The data obtained with the `get_rawdataset` function is an `np.ndarray` whose entries are SMILES `str` values.\n", 666 | "* A `custom Dataset` class wraps this raw dataset for convenient use.\n", 667 | "  * A simple example of writing a `custom Dataset` is in the [PyTorch tutorial](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files).\n", 668 | "  * The full code is in [`laiddmg/datasets.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/datasets.py)." 
669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": null, 674 | "metadata": { 675 | "ExecuteTime": { 676 | "end_time": "2021-08-30T11:37:06.193224Z", 677 | "start_time": "2021-08-30T11:37:06.002379Z" 678 | } 679 | }, 680 | "outputs": [], 681 | "source": [ 682 | "train = get_rawdataset('train')" 683 | ] 684 | }, 685 | { 686 | "cell_type": "code", 687 | "execution_count": null, 688 | "metadata": { 689 | "ExecuteTime": { 690 | "end_time": "2021-08-30T11:37:06.197988Z", 691 | "start_time": "2021-08-30T11:37:06.194621Z" 692 | } 693 | }, 694 | "outputs": [], 695 | "source": [ 696 | "train" 697 | ] 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": null, 702 | "metadata": { 703 | "ExecuteTime": { 704 | "end_time": "2021-08-30T11:37:06.386103Z", 705 | "start_time": "2021-08-30T11:37:06.382678Z" 706 | } 707 | }, 708 | "outputs": [], 709 | "source": [ 710 | "print(f'number of training dataset: {len(train)}')\n", 711 | "print(f'raw data type: {type(train[0])}')" 712 | ] 713 | }, 714 | { 715 | "cell_type": "markdown", 716 | "metadata": {}, 717 | "source": [ 718 | "#### Trying the `model` on sample data" 719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 723 | "execution_count": null, 724 | "metadata": { 725 | "ExecuteTime": { 726 | "end_time": "2021-08-30T11:37:08.408947Z", 727 | "start_time": "2021-08-30T11:37:08.401696Z" 728 | } 729 | }, 730 | "outputs": [], 731 | "source": [ 732 | "sampled_data = train[:4]\n", 733 | "inputs = tokenizer(sampled_data)\n", 734 | "inputs" 735 | ] 736 | }, 737 | { 738 | "cell_type": "markdown", 739 | "metadata": {}, 740 | "source": [ 741 | "#### Feeding `inputs` into the `model`" 742 | ] 743 | }, 744 | { 745 | "cell_type": "code", 746 | "execution_count": null, 747 | "metadata": { 748 | "ExecuteTime": { 749 | "end_time": "2021-08-30T11:37:09.984904Z", 750 | "start_time": "2021-08-30T11:37:08.956622Z" 751 | } 752 | }, 753 | "outputs": [], 754 | "source": [ 755 | "outputs, z_mu, z_logvar = model(**inputs)" 756 | ] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": null, 761 | "metadata": { 762 | "ExecuteTime": { 763 | "end_time": "2021-08-30T11:37:09.989282Z", 764 | "start_time": "2021-08-30T11:37:09.986266Z" 765 | } 766 | }, 767 | "outputs": [], 768 | "source": [ 769 | "print(f'outputs shape: {outputs.shape}')\n", 770 | "print(f'latent vector mean shape: {z_mu.shape}')\n", 771 | "print(f'latent vector log variance shape: {z_logvar.shape}')" 772 | ] 773 | }, 774 | { 775 | "cell_type": "markdown", 776 | "metadata": { 777 | "ExecuteTime": { 778 | "end_time": "2021-08-18T05:43:51.928351Z", 779 | "start_time": "2021-08-18T05:43:51.924835Z" 780 | } 781 | }, 782 | "source": [ 783 | "### Getting the PyTorch `Dataset` and `DataLoader`" 784 | ] 785 | }, 786 | { 787 | "cell_type": "code", 788 | "execution_count": null, 789 | "metadata": { 790 | "ExecuteTime": { 791 | "end_time": "2021-08-30T11:37:09.992050Z", 792 | "start_time": "2021-08-30T11:37:09.990274Z" 793 | } 794 | }, 795 | "outputs": [], 796 | "source": [ 797 | "from torch.utils.data import Dataset, DataLoader" 798 | ] 799 | }, 800 | { 801 | "cell_type": "code", 802 | "execution_count": null, 803 | "metadata": { 804 | "ExecuteTime": { 805 | "end_time": "2021-08-30T11:37:09.994402Z", 806 | "start_time": "2021-08-30T11:37:09.992842Z" 807 | } 808 | }, 809 | "outputs": [], 810 | "source": [ 811 | "train_dataset = get_dataset(train, tokenizer)" 812 | ] 813 | }, 814 | { 815 | "cell_type": "code", 816 | "execution_count": null, 817 | "metadata": { 818 | "ExecuteTime": { 819 | "end_time": 
"2021-08-30T11:37:10.013066Z", 820 | "start_time": "2021-08-30T11:37:10.007338Z" 821 | } 822 | }, 823 | "outputs": [], 824 | "source": [ 825 | "train_dataset[1000]" 826 | ] 827 | }, 828 | { 829 | "cell_type": "markdown", 830 | "metadata": { 831 | "ExecuteTime": { 832 | "end_time": "2021-08-18T06:53:56.747987Z", 833 | "start_time": "2021-08-18T06:53:56.743885Z" 834 | } 835 | }, 836 | "source": [ 837 | "#### `input_id`와 `target`의 관계\n", 838 | "\n", 839 | "RNN을 이용한 생성모델(generative model)은 language model의 학습방법을 이용한다.\n", 840 | "Language model은 간단하게 이야기하면 다음 단어를 예측하는 모델이다.\n", 841 | "다음 단어를 예측한다는 뜻은 RNN 그림을 보면 쉽게 이해할 수 있다.\n", 842 | "\n", 843 | "![RNN-input-target](https://user-images.githubusercontent.com/11681225/129859647-af31934a-0eea-4ad8-9a85-2d3c2a75f517.jpeg)\n", 844 | "\n", 845 | "위와 같이 input data의 token이 하나씩 이동한 것이 target이 되는 것이다." 846 | ] 847 | }, 848 | { 849 | "cell_type": "code", 850 | "execution_count": null, 851 | "metadata": { 852 | "ExecuteTime": { 853 | "end_time": "2021-08-30T11:37:12.906199Z", 854 | "start_time": "2021-08-30T11:37:12.902545Z" 855 | } 856 | }, 857 | "outputs": [], 858 | "source": [ 859 | "def _pad_sequence(data: List[torch.Tensor],\n", 860 | " padding_value: int = 0) -> torch.Tensor:\n", 861 | " return rnn_utils.pad_sequence(data,\n", 862 | " batch_first=True,\n", 863 | " padding_value=padding_value)" 864 | ] 865 | }, 866 | { 867 | "cell_type": "code", 868 | "execution_count": null, 869 | "metadata": { 870 | "ExecuteTime": { 871 | "end_time": "2021-08-30T11:37:13.350004Z", 872 | "start_time": "2021-08-30T11:37:13.342408Z" 873 | } 874 | }, 875 | "outputs": [], 876 | "source": [ 877 | "def _collate_fn(batch: List[Dict[str, Union[torch.Tensor, str, int]]],\n", 878 | " **kwargs) -> Dict[str, Union[torch.Tensor, List[str], List[int]]]:\n", 879 | "\n", 880 | " indexes = [item['index'] for item in batch]\n", 881 | " smiles = [item['smiles'] for item in batch]\n", 882 | " input_ids = [item['input_id'] for item in batch]\n", 883 | " targets = [item['target'] for item in batch]\n", 884 | " lengths = [item['length'] for item in batch]\n", 885 | "\n", 886 | " padding_value = tokenizer.padding_value\n", 887 | " input_ids = _pad_sequence(input_ids, padding_value)\n", 888 | " targets = _pad_sequence(targets, padding_value)\n", 889 | " lengths = torch.LongTensor(lengths)\n", 890 | "\n", 891 | " return {'input_ids': input_ids,\n", 892 | " 'targets': targets,\n", 893 | " 'lengths': lengths,\n", 894 | " 'smiles': smiles,\n", 895 | " 'indexes': indexes}" 896 | ] 897 | }, 898 | { 899 | "cell_type": "code", 900 | "execution_count": null, 901 | "metadata": { 902 | "ExecuteTime": { 903 | "end_time": "2021-08-30T11:37:13.984729Z", 904 | "start_time": "2021-08-30T11:37:13.981792Z" 905 | } 906 | }, 907 | "outputs": [], 908 | "source": [ 909 | "train_dataloader = DataLoader(train_dataset,\n", 910 | " batch_size=4,\n", 911 | " shuffle=True,\n", 912 | " collate_fn=_collate_fn,\n", 913 | " )" 914 | ] 915 | }, 916 | { 917 | "cell_type": "markdown", 918 | "metadata": { 919 | "ExecuteTime": { 920 | "end_time": "2021-08-18T23:33:34.497465Z", 921 | "start_time": "2021-08-18T23:33:34.494157Z" 922 | } 923 | }, 924 | "source": [ 925 | "### Train without `Trainer` class\n", 926 | "\n", 927 | "* 실제 사용할 수 있게 패키징한 코드에서는 `Trainer` class를 만들어 사용하기 편리하게 모듈화 시켰습니다.\n", 928 | "* 하지만 해당 Jupyter notebook은 이해를 돕기위해 모듈화 되어 있는 코드를 풀어서 블록 단위로 나타내었습니다.\n", 929 | "* `Trainer`에 관련된 자세한 코드는 아래 링크에 있습니다.\n", 930 | " * 
[`laiddmg/trainer.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/trainer.py)\n", 931 | "  * [`laiddmg/models/vae/vae_trainer.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/models/vae/vae_trainer.py)" 932 | ] 933 | }, 934 | { 935 | "cell_type": "markdown", 936 | "metadata": {}, 937 | "source": [ 938 | "#### Create the loss function and optimizer" 939 | ] 940 | }, 941 | { 942 | "cell_type": "code", 943 | "execution_count": null, 944 | "metadata": { 945 | "ExecuteTime": { 946 | "end_time": "2021-08-30T11:37:18.880424Z", 947 | "start_time": "2021-08-30T11:37:18.877125Z" 948 | } 949 | }, 950 | "outputs": [], 951 | "source": [ 952 | "training_args = EasyDict({\n", 953 | "    'output_dir': 'outputs/vae/jupyter1',\n", 954 | "    'num_train_epochs': 10,\n", 955 | "    'batch_size': 256,\n", 956 | "    'lr': 1e-3,\n", 957 | "})" 958 | ] 959 | }, 960 | { 961 | "cell_type": "code", 962 | "execution_count": null, 963 | "metadata": { 964 | "ExecuteTime": { 965 | "end_time": "2021-08-30T11:37:19.225155Z", 966 | "start_time": "2021-08-30T11:37:19.222014Z" 967 | } 968 | }, 969 | "outputs": [], 970 | "source": [ 971 | "train_dataloader = DataLoader(train_dataset,\n", 972 | "                              batch_size=training_args.batch_size,\n", 973 | "                              shuffle=True,\n", 974 | "                              collate_fn=_collate_fn,\n", 975 | "                              )" 976 | ] 977 | }, 978 | { 979 | "cell_type": "code", 980 | "execution_count": null, 981 | "metadata": { 982 | "ExecuteTime": { 983 | "end_time": "2021-08-30T11:37:19.936386Z", 984 | "start_time": "2021-08-30T11:37:19.931911Z" 985 | } 986 | }, 987 | "outputs": [], 988 | "source": [ 989 | "reconstruction_loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.padding_value)\n", 990 | "optimizer = optim.Adam(model.parameters(), lr=training_args.lr)" 991 | ] 992 | }, 993 | { 994 | "cell_type": "markdown", 995 | "metadata": {}, 996 | "source": [ 997 | "### Training" 998 | ] 999 | }, 1000 | { 1001 | "cell_type": "code", 1002 | "execution_count": null, 1003 | "metadata": { 1004 | "ExecuteTime": { 1005 | "end_time": "2021-08-30T11:37:22.312825Z", 1006 | "start_time": "2021-08-30T11:37:22.012946Z" 1007 | } 1008 | }, 1009 | "outputs": [], 1010 | "source": [ 1011 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 1012 | "print(device)" 1013 | ] 1014 | }, 1015 | { 1016 | "cell_type": "code", 1017 | "execution_count": null, 1018 | "metadata": { 1019 | "ExecuteTime": { 1020 | "end_time": "2021-08-30T11:37:26.176656Z", 1021 | "start_time": "2021-08-30T11:37:22.356732Z" 1022 | } 1023 | }, 1024 | "outputs": [], 1025 | "source": [ 1026 | "model = model.to(device)\n", 1027 | "print(model.device)" 1028 | ] 1029 | }, 1030 | { 1031 | "cell_type": "markdown", 1032 | "metadata": { 1033 | "ExecuteTime": { 1034 | "end_time": "2021-08-23T04:59:44.110368Z", 1035 | "start_time": "2021-08-23T04:59:44.107785Z" 1036 | } 1037 | }, 1038 | "source": [ 1039 | "### KL annealing" 1040 | ] 1041 | }, 1042 | { 1043 | "cell_type": "code", 1044 | "execution_count": null, 1045 | "metadata": { 1046 | "ExecuteTime": { 1047 | "end_time": "2021-08-30T11:37:26.452470Z", 1048 | "start_time": "2021-08-30T11:37:26.179438Z" 1049 | } 1050 | }, 1051 | "outputs": [], 1052 | "source": [ 1053 | "from laiddmg.utils import AnnealingSchedules  # AnnealingSchedules lives in laiddmg/utils.py\n", 1054 | "import matplotlib.pyplot as plt\n", 1055 | "%matplotlib inline" 1056 | ] 1057 | }, 1058 | { 1059 | "cell_type": "code", 1060 | "execution_count": null, 1061 | "metadata": { 1062 | "ExecuteTime": { 1063 | "end_time": "2021-08-30T11:37:29.790299Z", 1064 | "start_time": 
"2021-08-30T11:37:29.786902Z" 1065 | } 1066 | }, 1067 | "outputs": [], 1068 | "source": [ 1069 | "kl_annealing = AnnealingSchedules(\n", 1070 | " method='cycle_sigmoid', # cycle_linear, cycle_sigmoid, cycle_cosine\n", 1071 | " update_unit='step', # epoch, step\n", 1072 | " num_training_steps=100,\n", 1073 | " num_training_steps_per_epoch=10,\n", 1074 | " start_weight=0.0,\n", 1075 | " stop_weight=1.0,\n", 1076 | ")" 1077 | ] 1078 | }, 1079 | { 1080 | "cell_type": "code", 1081 | "execution_count": null, 1082 | "metadata": { 1083 | "ExecuteTime": { 1084 | "end_time": "2021-08-30T11:37:30.230042Z", 1085 | "start_time": "2021-08-30T11:37:30.227233Z" 1086 | } 1087 | }, 1088 | "outputs": [], 1089 | "source": [ 1090 | "kl_annealing_weight = []\n", 1091 | "for step in range(100):\n", 1092 | " kl_annealing_weight.append(kl_annealing(step)) " 1093 | ] 1094 | }, 1095 | { 1096 | "cell_type": "code", 1097 | "execution_count": null, 1098 | "metadata": { 1099 | "ExecuteTime": { 1100 | "end_time": "2021-08-30T11:37:30.680120Z", 1101 | "start_time": "2021-08-30T11:37:30.592092Z" 1102 | } 1103 | }, 1104 | "outputs": [], 1105 | "source": [ 1106 | "plt.plot(kl_annealing_weight)" 1107 | ] 1108 | }, 1109 | { 1110 | "cell_type": "markdown", 1111 | "metadata": {}, 1112 | "source": [ 1113 | "### Training" 1114 | ] 1115 | }, 1116 | { 1117 | "cell_type": "code", 1118 | "execution_count": null, 1119 | "metadata": { 1120 | "ExecuteTime": { 1121 | "end_time": "2021-08-30T11:37:32.937614Z", 1122 | "start_time": "2021-08-30T11:37:32.933446Z" 1123 | } 1124 | }, 1125 | "outputs": [], 1126 | "source": [ 1127 | "def save_model(epoch: int, global_step: int, model: nn.Module):\n", 1128 | " checkpoint_dir = os.path.join(training_args.output_dir)\n", 1129 | " os.makedirs(checkpoint_dir, exist_ok=True)\n", 1130 | " ckpt_name = f'ckpt_{epoch:03d}.pt'\n", 1131 | " ckpt_path = os.path.join(checkpoint_dir, ckpt_name)\n", 1132 | " \n", 1133 | " torch.save({'epoch': epoch,\n", 1134 | " 'global_step': global_step,\n", 1135 | " 'model_state_dict': model.state_dict()},\n", 1136 | " ckpt_path)\n", 1137 | " print(f'saved {model.config.model_type} model at epoch {epoch}.')" 1138 | ] 1139 | }, 1140 | { 1141 | "cell_type": "markdown", 1142 | "metadata": {}, 1143 | "source": [ 1144 | "### KL Divergence Loss\n", 1145 | "\n", 1146 | "![vae kl loss](https://user-images.githubusercontent.com/11681225/131319712-3ca94a3c-0f72-4b9b-9a37-1b91e53c608c.jpeg)" 1147 | ] 1148 | }, 1149 | { 1150 | "cell_type": "code", 1151 | "execution_count": null, 1152 | "metadata": { 1153 | "ExecuteTime": { 1154 | "end_time": "2021-08-30T11:38:05.355634Z", 1155 | "start_time": "2021-08-30T11:37:57.376633Z" 1156 | }, 1157 | "scrolled": true 1158 | }, 1159 | "outputs": [], 1160 | "source": [ 1161 | "model.train()\n", 1162 | "global_step = 0\n", 1163 | "\n", 1164 | "kl_annealing = AnnealingSchedules(\n", 1165 | " method='cycle_linear', # cycle_linear, cycle_sigmoid, cycle_cosine\n", 1166 | " update_unit='epoch', # epoch, step\n", 1167 | " num_training_steps=len(train_dataloader) * training_args.num_train_epochs,\n", 1168 | " num_training_steps_per_epoch=len(train_dataloader),\n", 1169 | " start_weight=0.0,\n", 1170 | " stop_weight=0.05,\n", 1171 | ")\n", 1172 | "\n", 1173 | "for epoch in range(1, training_args.num_train_epochs + 1):\n", 1174 | " print(f'\\nStart training: {epoch} Epoch\\n')\n", 1175 | " \n", 1176 | " for i, data in enumerate(train_dataloader, 1):\n", 1177 | " optimizer.zero_grad()\n", 1178 | " \n", 1179 | " data['input_ids'] = 
data['input_ids'].to(device)\n", 1180 | "        data['targets'] = data['targets'].to(device)\n", 1181 | "        outputs, z_mu, z_logvar = model(data['input_ids'], data['lengths'])\n", 1182 | "\n", 1183 | "        reconstruction_loss = reconstruction_loss_fn(outputs.view(-1, outputs.shape[-1]),\n", 1184 | "                                                     data['targets'].view(-1))\n", 1185 | "\n", 1186 | "        kl_loss = .5 * (torch.exp(z_logvar) + z_mu**2 - 1. - z_logvar).sum(1).mean()  # closed-form KL to the N(0, I) prior\n", 1187 | "\n", 1188 | "        kl_annealing_weight = kl_annealing(global_step)\n", 1189 | "\n", 1190 | "        total_loss = reconstruction_loss + kl_annealing_weight * kl_loss\n", 1191 | "\n", 1192 | "        total_loss.backward()\n", 1193 | "        nn.utils.clip_grad_norm_(model.parameters(),\n", 1194 | "                                 max_norm=50)\n", 1195 | "        optimizer.step()\n", 1196 | "        global_step += 1\n", 1197 | "\n", 1198 | "        if global_step % 100 == 0:\n", 1199 | "            print(f'{epoch} Epochs | {i}/{len(train_dataloader)} | reconst_loss: {reconstruction_loss.item():.4g} | '\n", 1200 | "                  f'kl_loss: {kl_loss:.4g}, total_loss: {total_loss:.4g}, '\n", 1201 | "                  f'kl_annealing: {kl_annealing_weight:.4g}')\n", 1202 | "\n", 1203 | "    save_model(epoch, global_step, model)\n", 1204 | "\n", 1205 | "print('Training done!!')" 1206 | ] 1207 | }, 1208 | { 1209 | "cell_type": "markdown", 1210 | "metadata": {}, 1211 | "source": [ 1212 | "## Generate new SMILES\n", 1213 | "\n", 1214 | "* After training the model, we `load` the trained weights to get ready to generate SMILES.\n", 1215 | "* The `model.generate` method produces new SMILES sequences.\n", 1216 | "* Here each step of the generation process is explained one at a time.\n", 1217 | "* The full code is in [`laiddmg/generate.py`](https://github.com/ilguyi/LAIDD-molecular-generation/blob/main/laiddmg/generate.py)." 1218 | ] 1219 | }, 1220 | { 1221 | "cell_type": "code", 1222 | "execution_count": null, 1223 | "metadata": { 1224 | "ExecuteTime": { 1225 | "end_time": "2021-08-29T06:49:44.372607Z", 1226 | "start_time": "2021-08-29T06:49:44.338414Z" 1227 | } 1228 | }, 1229 | "outputs": [], 1230 | "source": [ 1231 | "checkpoint_dir = training_args.output_dir\n", 1232 | "model = VAEModel.from_pretrained(config, os.path.join(f'{checkpoint_dir}',\n", 1233 | "                                                      f'ckpt_{training_args.num_train_epochs:03d}.pt'))\n", 1234 | "model.eval()" 1235 | ] 1236 | }, 1237 | { 1238 | "cell_type": "markdown", 1239 | "metadata": {}, 1240 | "source": [ 1241 | "* Due to time constraints, this lecture downloads and uses a pre-trained `best_model`." 
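For reference (editor-added), the `kl_loss` line in the VAE training loop above is the closed-form KL divergence between the approximate posterior and the standard-normal prior, with the variance recovered as the exponential of `z_logvar` and the sum running over the latent dimensions:

```latex
D_{\mathrm{KL}}\left(\mathcal{N}(\mu,\sigma^{2}) \,\|\, \mathcal{N}(0,I)\right)
= \frac{1}{2}\sum_{j=1}^{d}\left(\sigma_{j}^{2} + \mu_{j}^{2} - 1 - \log\sigma_{j}^{2}\right)
```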
1242 | ] 1243 | }, 1244 | { 1245 | "cell_type": "code", 1246 | "execution_count": null, 1247 | "metadata": { 1248 | "ExecuteTime": { 1249 | "end_time": "2021-08-30T11:38:34.112875Z", 1250 | "start_time": "2021-08-30T11:38:27.853124Z" 1251 | } 1252 | }, 1253 | "outputs": [], 1254 | "source": [ 1255 | "!wget 'https://www.dropbox.com/s/751pqnlgwqnqkby/vae_best.tar.gz?dl=0'\n", 1256 | "!tar xvzf vae_best.tar.gz?dl=0\n", 1257 | "!rm -f vae_best.tar.gz?dl=0\n", 1258 | "!mv best_model/ outputs/vae/" 1259 | ] 1260 | }, 1261 | { 1262 | "cell_type": "code", 1263 | "execution_count": null, 1264 | "metadata": { 1265 | "ExecuteTime": { 1266 | "end_time": "2021-08-30T11:38:37.590985Z", 1267 | "start_time": "2021-08-30T11:38:37.528556Z" 1268 | } 1269 | }, 1270 | "outputs": [], 1271 | "source": [ 1272 | "model = VAEModel.from_pretrained(config,\n", 1273 | "                                 os.path.join('./outputs/vae/best_model/best_model.pt'))\n", 1274 | "model = model.to(device)\n", 1275 | "model.eval()" 1276 | ] 1277 | }, 1278 | { 1279 | "cell_type": "code", 1280 | "execution_count": null, 1281 | "metadata": { 1282 | "ExecuteTime": { 1283 | "end_time": "2021-08-30T11:38:40.351879Z", 1284 | "start_time": "2021-08-30T11:38:40.348252Z" 1285 | } 1286 | }, 1287 | "outputs": [], 1288 | "source": [ 1289 | "batch_size_for_generate = 4" 1290 | ] 1291 | }, 1292 | { 1293 | "cell_type": "code", 1294 | "execution_count": null, 1295 | "metadata": { 1296 | "ExecuteTime": { 1297 | "end_time": "2021-08-30T11:38:40.954145Z", 1298 | "start_time": "2021-08-30T11:38:40.682456Z" 1299 | } 1300 | }, 1301 | "outputs": [], 1302 | "source": [ 1303 | "outputs = model.generate(tokenizer=tokenizer,\n", 1304 | "                         max_length=128,\n", 1305 | "                         num_return_sequences=batch_size_for_generate,\n", 1306 | "                         skip_special_tokens=True)" 1307 | ] 1308 | }, 1309 | { 1310 | "cell_type": "code", 1311 | "execution_count": null, 1312 | "metadata": { 1313 | "ExecuteTime": { 1314 | "end_time": "2021-08-30T11:38:40.966093Z", 1315 | "start_time": "2021-08-30T11:38:40.961391Z" 1316 | } 1317 | }, 1318 | "outputs": [], 1319 | "source": [ 1320 | "outputs" 1321 | ] 1322 | }, 1323 | { 1324 | "cell_type": "markdown", 1325 | "metadata": { 1326 | "ExecuteTime": { 1327 | "end_time": "2021-08-20T01:23:46.853695Z", 1328 | "start_time": "2021-08-20T01:23:46.850494Z" 1329 | } 1330 | }, 1331 | "source": [ 1332 | "### The generation process, step by step\n", 1333 | "\n", 1334 | "* step 1. Put the start token (`tokenizer.start_token`) into the `input_ids` variable\n", 1335 | "* step 2. Sample a latent vector `z` from the prior distribution (a Gaussian distribution)\n", 1336 | "* step 3. Apply a Linear layer to latent vector `z` so it can serve as the decoder's initial state\n", 1337 | "* step 4. Feed `input_ids` through the `embedding` to get the embedded input\n", 1338 | "* step 5. Concatenate latent vector `z` onto (every step of) the input data\n", 1339 | "* step 6. Feed the concatenated input into the gru\n", 1340 | "* step 7. Pass `outputs` through the `Linear` layer to get `next_token_logits`\n", 1341 | "* step 8. Apply `softmax` to `next_token_logits` to get a probability distribution\n", 1342 | "* step 9. Sample from that probability distribution (using `torch.multinomial`)\n", 1343 | "* step 10. The sampled values become `next_tokens`, which serves as the rnn input at the next step (`input_ids = next_tokens`)\n", 1344 | "* step 11. 
1400 | {
1401 | "cell_type": "code",
1402 | "execution_count": null,
1403 | "metadata": {
1404 | "ExecuteTime": {
1405 | "end_time": "2021-08-30T11:38:44.953529Z",
1406 | "start_time": "2021-08-30T11:38:44.947756Z"
1407 | }
1408 | },
1409 | "outputs": [],
1410 | "source": [
1411 | "model = model.to(device)"
1412 | ]
1413 | },
1414 | {
1415 | "cell_type": "markdown",
1416 | "metadata": {
1417 | "ExecuteTime": {
1418 | "end_time": "2021-08-23T07:14:20.751167Z",
1419 | "start_time": "2021-08-23T07:14:20.748749Z"
1420 | }
1421 | },
1422 | "source": [
1423 | "#### step 1. put the start token into the `input_ids` variable"
1424 | ]
1425 | },
1426 | {
1427 | "cell_type": "code",
1428 | "execution_count": null,
1429 | "metadata": {
1430 | "ExecuteTime": {
1431 | "end_time": "2021-08-30T11:38:45.476419Z",
1432 | "start_time": "2021-08-30T11:38:45.472138Z"
1433 | }
1434 | },
1435 | "outputs": [],
1436 | "source": [
1437 | "initial_inputs = torch.full((batch_size_for_generate, 1),\n",
1438 | "                            tokenizer.convert_token_to_id(tokenizer.start_token),\n",
1439 | "                            dtype=torch.long,\n",
1440 | "                            device=model.device)\n",
1441 | "generated_sequences = initial_inputs\n",
1442 | "input_ids = initial_inputs"
1443 | ]
1444 | },
1445 | {
1446 | "cell_type": "code",
1447 | "execution_count": null,
1448 | "metadata": {
1449 | "ExecuteTime": {
1450 | "end_time": "2021-08-30T11:38:45.652039Z",
1451 | "start_time": "2021-08-30T11:38:45.647513Z"
1452 | }
1453 | },
1454 | "outputs": [],
1455 | "source": [
1456 | "input_ids"
1457 | ]
1458 | },
1459 | {
1460 | "cell_type": "markdown",
1461 | "metadata": {
1462 | "ExecuteTime": {
1463 | "end_time": "2021-08-23T07:06:47.477691Z",
1464 | "start_time": "2021-08-23T07:06:47.473449Z"
1465 | }
1466 | },
1467 | "source": [
1468 | "#### step 2. sample a latent vector `z` from the prior distribution (Gaussian distribution)"
1469 | ]
1470 | },
1471 | {
1472 | "cell_type": "code",
1473 | "execution_count": null,
1474 | "metadata": {
1475 | "ExecuteTime": {
1476 | "end_time": "2021-08-30T11:38:45.980539Z",
1477 | "start_time": "2021-08-30T11:38:45.976903Z"
1478 | }
1479 | },
1480 | "outputs": [],
1481 | "source": [
1482 | "z = model.sample_gaussian_dist(batch_size_for_generate)  # z: [batch_size, latent_dim]\n",
1483 | "z_ = z.unsqueeze(1)  # z_: [batch_size, 1, latent_dim]\n",
1484 | "# z_: reshaped so it can be concatenated with the input in step 5"
1485 | ]
1486 | },
1487 | {
1488 | "cell_type": "code",
1489 | "execution_count": null,
1490 | "metadata": {
1491 | "ExecuteTime": {
1492 | "end_time": "2021-08-30T11:38:46.136095Z",
1493 | "start_time": "2021-08-30T11:38:46.132563Z"
1494 | }
1495 | },
1496 | "outputs": [],
1497 | "source": [
1498 | "print(z.shape)\n",
1499 | "print(z_.shape)"
1500 | ]
1501 | },
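1502 | {
1503 | "cell_type": "markdown",
1504 | "metadata": {},
1505 | "source": [
1506 | "`sample_gaussian_dist` draws `z` from the standard normal prior. A minimal equivalent might look like the sketch below (illustrative; the exact signature and where `latent_dim` lives are assumptions):"
1507 | ]
1508 | },
1509 | {
1510 | "cell_type": "code",
1511 | "execution_count": null,
1512 | "metadata": {},
1513 | "outputs": [],
1514 | "source": [
1515 | "# Illustrative stand-in for model.sample_gaussian_dist:\n",
1516 | "# draw z ~ N(0, I) with shape [batch_size, latent_dim] on a given device.\n",
1517 | "def sample_gaussian_dist(batch_size, latent_dim, device='cpu'):\n",
1518 | "    return torch.randn(batch_size, latent_dim, device=device)"
1519 | ]
1520 | },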
1521 | {
1522 | "cell_type": "markdown",
1523 | "metadata": {
1524 | "ExecuteTime": {
1525 | "end_time": "2021-08-23T07:08:32.158298Z",
1526 | "start_time": "2021-08-23T07:08:32.153435Z"
1527 | }
1528 | },
1529 | "source": [
1530 | "#### step 3. apply a `Linear` layer to the latent vector `z` so it can be used as the decoder's initial state"
1531 | ]
1532 | },
1533 | {
1534 | "cell_type": "code",
1535 | "execution_count": null,
1536 | "metadata": {
1537 | "ExecuteTime": {
1538 | "end_time": "2021-08-30T11:38:46.445527Z",
1539 | "start_time": "2021-08-30T11:38:46.441594Z"
1540 | }
1541 | },
1542 | "outputs": [],
1543 | "source": [
1544 | "hiddens = model.decoder.z2hidden(z)  # hiddens: [batch_size, hidden_dim]\n",
1545 | "hiddens = hiddens.unsqueeze(0).repeat(model.config.decoder_num_layers, 1, 1)"
1546 | ]
1547 | },
1548 | {
1549 | "cell_type": "code",
1550 | "execution_count": null,
1551 | "metadata": {
1552 | "ExecuteTime": {
1553 | "end_time": "2021-08-30T11:38:46.615648Z",
1554 | "start_time": "2021-08-30T11:38:46.612210Z"
1555 | }
1556 | },
1557 | "outputs": [],
1558 | "source": [
1559 | "print(hiddens.shape)  # [decoder.num_layers, batch_size, hidden_dim]"
1560 | ]
1561 | },
1562 | {
1563 | "cell_type": "markdown",
1564 | "metadata": {
1565 | "ExecuteTime": {
1566 | "end_time": "2021-08-23T07:10:33.802845Z",
1567 | "start_time": "2021-08-23T07:10:33.799144Z"
1568 | }
1569 | },
1570 | "source": [
1571 | "#### step 4. feed the `input_ids` data into the `embedding` to obtain the embedded input"
1572 | ]
1573 | },
1574 | {
1575 | "cell_type": "code",
1576 | "execution_count": null,
1577 | "metadata": {
1578 | "ExecuteTime": {
1579 | "end_time": "2021-08-30T11:38:46.945212Z",
1580 | "start_time": "2021-08-30T11:38:46.942350Z"
1581 | }
1582 | },
1583 | "outputs": [],
1584 | "source": [
1585 | "x = model.embeddings(input_ids)  # x: [batch_size, 1, embedding_dim]"
1586 | ]
1587 | },
1588 | {
1589 | "cell_type": "code",
1590 | "execution_count": null,
1591 | "metadata": {
1592 | "ExecuteTime": {
1593 | "end_time": "2021-08-30T11:38:47.181696Z",
1594 | "start_time": "2021-08-30T11:38:47.177207Z"
1595 | }
1596 | },
1597 | "outputs": [],
1598 | "source": [
1599 | "x.shape"
1600 | ]
1601 | },
1602 | {
1603 | "cell_type": "markdown",
1604 | "metadata": {
1605 | "ExecuteTime": {
1606 | "end_time": "2021-08-23T07:15:20.068978Z",
1607 | "start_time": "2021-08-23T07:15:20.064783Z"
1608 | }
1609 | },
1610 | "source": [
1611 | "#### step 5. concatenate the latent vector `z` to (every) input\n",
1612 | "\n",
1613 | "* the latent vector's information is injected at every token in order to improve performance"
1614 | ]
1615 | },
1616 | {
1617 | "cell_type": "code",
1618 | "execution_count": null,
1619 | "metadata": {
1620 | "ExecuteTime": {
1621 | "end_time": "2021-08-30T11:38:47.515270Z",
1622 | "start_time": "2021-08-30T11:38:47.511911Z"
1623 | }
1624 | },
1625 | "outputs": [],
1626 | "source": [
1627 | "x = torch.cat((x, z_), dim=-1)  # x: [batch_size, 1, embedding_dim + latent_dim]"
1628 | ]
1629 | },
1630 | {
1631 | "cell_type": "markdown",
1632 | "metadata": {},
1633 | "source": [
1634 | "#### step 6. feed the concatenated input into the GRU"
1635 | ]
1636 | },
1637 | {
1638 | "cell_type": "code",
1639 | "execution_count": null,
1640 | "metadata": {
1641 | "ExecuteTime": {
1642 | "end_time": "2021-08-30T11:38:47.855560Z",
1643 | "start_time": "2021-08-30T11:38:47.851942Z"
1644 | }
1645 | },
1646 | "outputs": [],
1647 | "source": [
1648 | "x, hiddens = model.decoder.gru(x, hiddens)"
1649 | ]
1650 | },
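1651 | {
1652 | "cell_type": "markdown",
1653 | "metadata": {},
1654 | "source": [
1655 | "As with the earlier shape checks, it is worth confirming the tensor layout after one GRU step; with a batch-first GRU (which the `[batch_size, 1, ...]` shapes above imply), the output keeps the `[batch, seq, hidden]` layout:"
1656 | ]
1657 | },
1658 | {
1659 | "cell_type": "code",
1660 | "execution_count": null,
1661 | "metadata": {},
1662 | "outputs": [],
1663 | "source": [
1664 | "print(x.shape)        # expected: [batch_size, 1, hidden_dim]\n",
1665 | "print(hiddens.shape)  # expected: [decoder_num_layers, batch_size, hidden_dim]"
1666 | ]
1667 | },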
1668 | {
1669 | "cell_type": "markdown",
1670 | "metadata": {
1671 | "ExecuteTime": {
1672 | "end_time": "2021-08-23T07:30:43.052727Z",
1673 | "start_time": "2021-08-23T07:30:43.050241Z"
1674 | }
1675 | },
1676 | "source": [
1677 | "#### step 7. pass the `outputs` through a `Linear` layer to obtain `next_token_logits`"
1678 | ]
1679 | },
1680 | {
1681 | "cell_type": "code",
1682 | "execution_count": null,
1683 | "metadata": {
1684 | "ExecuteTime": {
1685 | "end_time": "2021-08-30T11:38:48.186674Z",
1686 | "start_time": "2021-08-30T11:38:48.182332Z"
1687 | }
1688 | },
1689 | "outputs": [],
1690 | "source": [
1691 | "logits = model.decoder.fc(x)\n",
1692 | "next_token_logits = logits.squeeze(1)"
1693 | ]
1694 | },
1695 | {
1696 | "cell_type": "code",
1697 | "execution_count": null,
1698 | "metadata": {
1699 | "ExecuteTime": {
1700 | "end_time": "2021-08-30T11:38:48.361825Z",
1701 | "start_time": "2021-08-30T11:38:48.357422Z"
1702 | }
1703 | },
1704 | "outputs": [],
1705 | "source": [
1706 | "logits.shape"
1707 | ]
1708 | },
1709 | {
1710 | "cell_type": "code",
1711 | "execution_count": null,
1712 | "metadata": {
1713 | "ExecuteTime": {
1714 | "end_time": "2021-08-30T11:38:48.506411Z",
1715 | "start_time": "2021-08-30T11:38:48.501838Z"
1716 | }
1717 | },
1718 | "outputs": [],
1719 | "source": [
1720 | "next_token_logits.shape"
1721 | ]
1722 | },
1723 | {
1724 | "cell_type": "markdown",
1725 | "metadata": {
1726 | "ExecuteTime": {
1727 | "end_time": "2021-08-20T01:52:29.148625Z",
1728 | "start_time": "2021-08-20T01:52:29.145709Z"
1729 | }
1730 | },
1731 | "source": [
1732 | "#### step 8. turn `next_token_logits` into a probability distribution via `softmax`"
1733 | ]
1734 | },
1735 | {
1736 | "cell_type": "code",
1737 | "execution_count": null,
1738 | "metadata": {
1739 | "ExecuteTime": {
1740 | "end_time": "2021-08-30T11:38:48.869908Z",
1741 | "start_time": "2021-08-30T11:38:48.866828Z"
1742 | }
1743 | },
1744 | "outputs": [],
1745 | "source": [
1746 | "probabilities = F.softmax(next_token_logits, dim=-1)"
1747 | ]
1748 | },
1749 | {
1750 | "cell_type": "code",
1751 | "execution_count": null,
1752 | "metadata": {
1753 | "ExecuteTime": {
1754 | "end_time": "2021-08-30T11:38:49.214826Z",
1755 | "start_time": "2021-08-30T11:38:49.207372Z"
1756 | }
1757 | },
1758 | "outputs": [],
1759 | "source": [
1760 | "probabilities[0]"
1761 | ]
1762 | },
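1763 | {
1764 | "cell_type": "markdown",
1765 | "metadata": {},
1766 | "source": [
1767 | "A common extension at this point (not used in this notebook) is temperature sampling: dividing the logits by a temperature `T` before the softmax sharpens the distribution for `T < 1` and flattens it for `T > 1`. A sketch, with `temperature` as a hypothetical knob:"
1768 | ]
1769 | },
1770 | {
1771 | "cell_type": "code",
1772 | "execution_count": null,
1773 | "metadata": {},
1774 | "outputs": [],
1775 | "source": [
1776 | "# Illustrative only: temperature-scaled softmax (not part of this notebook's model).\n",
1777 | "temperature = 0.8\n",
1778 | "probabilities_t = F.softmax(next_token_logits / temperature, dim=-1)"
1779 | ]
1780 | },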
1781 | {
1782 | "cell_type": "markdown",
1783 | "metadata": {
1784 | "ExecuteTime": {
1785 | "end_time": "2021-08-20T01:52:29.586382Z",
1786 | "start_time": "2021-08-20T01:52:29.581315Z"
1787 | }
1788 | },
1789 | "source": [
1790 | "#### step 9. sample from this probability distribution (using [`torch.multinomial`](https://pytorch.org/docs/stable/generated/torch.multinomial.html?highlight=multinomial#torch.multinomial))"
1791 | ]
1792 | },
1793 | {
1794 | "cell_type": "code",
1795 | "execution_count": null,
1796 | "metadata": {
1797 | "ExecuteTime": {
1798 | "end_time": "2021-08-30T11:38:49.871639Z",
1799 | "start_time": "2021-08-30T11:38:49.862204Z"
1800 | }
1801 | },
1802 | "outputs": [],
1803 | "source": [
1804 | "next_tokens = torch.multinomial(probabilities, num_samples=1)\n",
1805 | "next_tokens"
1806 | ]
1807 | },
1808 | {
1809 | "cell_type": "markdown",
1810 | "metadata": {
1811 | "ExecuteTime": {
1812 | "end_time": "2021-08-20T01:49:41.633366Z",
1813 | "start_time": "2021-08-20T01:49:41.626309Z"
1814 | }
1815 | },
1816 | "source": [
1817 | "For reference, `tokenizer.vocab`:\n",
1818 | "\n",
1819 | "```python\n",
1820 | "{('#', 4), ('(', 5), (')', 6), ('-', 7),\n",
1821 | " ('1', 8), ('2', 9), ('3', 10), ('4', 11), ('5', 12), ('6', 13), ('=', 14),\n",
1822 | " ('B', 15), ('C', 16), ('F', 17), ('H', 18), ('N', 19), ('O', 20), ('S', 21),\n",
1823 | " ('[', 22), (']', 23), ('c', 24), ('l', 25), ('n', 26), ('o', 27), ('r', 28), ('s', 29)}\n",
1824 | "```"
1825 | ]
1826 | },
1827 | {
1828 | "cell_type": "markdown",
1829 | "metadata": {
1830 | "ExecuteTime": {
1831 | "end_time": "2021-08-23T07:36:08.564746Z",
1832 | "start_time": "2021-08-23T07:36:08.561366Z"
1833 | }
1834 | },
1835 | "source": [
1836 | "#### step 10. the sampled values become `next_tokens`, which are used as the RNN input at the next step (`input_ids = next_tokens`)"
1837 | ]
1838 | },
1839 | {
1840 | "cell_type": "code",
1841 | "execution_count": null,
1842 | "metadata": {
1843 | "ExecuteTime": {
1844 | "end_time": "2021-08-30T11:38:52.510608Z",
1845 | "start_time": "2021-08-30T11:38:52.503629Z"
1846 | }
1847 | },
1848 | "outputs": [],
1849 | "source": [
1850 | "input_ids = next_tokens\n",
1851 | "generated_sequences = torch.cat((generated_sequences, next_tokens), dim=1)\n",
1852 | "generated_sequences"
1853 | ]
1854 | },
1855 | {
1856 | "cell_type": "markdown",
1857 | "metadata": {
1858 | "ExecuteTime": {
1859 | "end_time": "2021-08-20T02:04:21.601326Z",
1860 | "start_time": "2021-08-20T02:04:21.598478Z"
1861 | }
1862 | },
1863 | "source": [
1864 | "#### The steps above are packaged into the `generate` function."
1865 | ]
1866 | },
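1867 | {
1868 | "cell_type": "markdown",
1869 | "metadata": {},
1870 | "source": [
1871 | "To inspect the raw loop output, the ids in `generated_sequences` have to be mapped back to characters. A sketch, assuming `tokenizer.vocab` iterates over the `(token, id)` pairs shown above (this inverse lookup is illustrative, not a library API):"
1872 | ]
1873 | },
1874 | {
1875 | "cell_type": "code",
1876 | "execution_count": null,
1877 | "metadata": {},
1878 | "outputs": [],
1879 | "source": [
1880 | "# Illustrative decoding of id sequences back to text; ids without an entry\n",
1881 | "# (e.g. special tokens) are skipped, mimicking skip_special_tokens=True.\n",
1882 | "id_to_token = {i: t for t, i in tokenizer.vocab}\n",
1883 | "[''.join(id_to_token.get(i, '') for i in seq) for seq in generated_sequences.tolist()]"
1884 | ]
1885 | },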
1886 | {
1887 | "cell_type": "code",
1888 | "execution_count": null,
1889 | "metadata": {
1890 | "ExecuteTime": {
1891 | "end_time": "2021-08-30T11:38:55.639124Z",
1892 | "start_time": "2021-08-30T11:38:55.167476Z"
1893 | }
1894 | },
1895 | "outputs": [],
1896 | "source": [
1897 | "outputs = model.generate(tokenizer=tokenizer,\n",
1898 | "                         max_length=128,\n",
1899 | "                         # num_return_sequences=batch_size_for_generate,\n",
1900 | "                         num_return_sequences=256,\n",
1901 | "                         skip_special_tokens=True)"
1902 | ]
1903 | },
1904 | {
1905 | "cell_type": "code",
1906 | "execution_count": null,
1907 | "metadata": {
1908 | "ExecuteTime": {
1909 | "end_time": "2021-08-30T11:38:56.681864Z",
1910 | "start_time": "2021-08-30T11:38:56.628297Z"
1911 | }
1912 | },
1913 | "outputs": [],
1914 | "source": [
1915 | "import rdkit\n",
1916 | "from rdkit import Chem\n",
1917 | "from rdkit.Chem.Draw import IPythonConsole"
1918 | ]
1919 | },
1920 | {
1921 | "cell_type": "code",
1922 | "execution_count": null,
1923 | "metadata": {
1924 | "ExecuteTime": {
1925 | "end_time": "2021-08-30T11:38:57.046557Z",
1926 | "start_time": "2021-08-30T11:38:57.016948Z"
1927 | },
1928 | "scrolled": true
1929 | },
1930 | "outputs": [],
1931 | "source": [
1932 | "mols = []\n",
1933 | "for s in outputs:\n",
1934 | "    try:\n",
1935 | "        mol = Chem.MolFromSmiles(s)\n",
1936 | "    except Exception:\n",
1937 | "        mol = None  # reset so a failed parse cannot reuse the previous mol\n",
1938 | "    if mol is not None:\n",
1939 | "        mols.append(mol)"
1940 | ]
1941 | },
1942 | {
1943 | "cell_type": "code",
1944 | "execution_count": null,
1945 | "metadata": {
1946 | "ExecuteTime": {
1947 | "end_time": "2021-08-30T11:38:57.871409Z",
1948 | "start_time": "2021-08-30T11:38:57.867306Z"
1949 | }
1950 | },
1951 | "outputs": [],
1952 | "source": [
1953 | "len(mols)"
1954 | ]
1955 | },
1956 | {
1957 | "cell_type": "code",
1958 | "execution_count": null,
1959 | "metadata": {
1960 | "ExecuteTime": {
1961 | "end_time": "2021-08-30T11:38:58.526726Z",
1962 | "start_time": "2021-08-30T11:38:58.517659Z"
1963 | }
1964 | },
1965 | "outputs": [],
1966 | "source": [
1967 | "mols[0]"
1968 | ]
1969 | },
1970 | {
1971 | "cell_type": "code",
1972 | "execution_count": null,
1973 | "metadata": {
1974 | "ExecuteTime": {
1975 | "end_time": "2021-08-30T11:38:58.910300Z",
1976 | "start_time": "2021-08-30T11:38:58.902537Z"
1977 | }
1978 | },
1979 | "outputs": [],
1980 | "source": [
1981 | "mols[1]"
1982 | ]
1983 | },
1984 | {
1985 | "cell_type": "code",
1986 | "execution_count": null,
1987 | "metadata": {
1988 | "ExecuteTime": {
1989 | "end_time": "2021-08-30T11:38:59.324041Z",
1990 | "start_time": "2021-08-30T11:38:59.316513Z"
1991 | }
1992 | },
1993 | "outputs": [],
1994 | "source": [
1995 | "mols[2]"
1996 | ]
1997 | },
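1998 | {
1999 | "cell_type": "markdown",
2000 | "metadata": {},
2001 | "source": [
2002 | "A quick sanity metric for the batch above is validity: the fraction of generated strings that RDKit can parse into molecules."
2003 | ]
2004 | },
2005 | {
2006 | "cell_type": "code",
2007 | "execution_count": null,
2008 | "metadata": {},
2009 | "outputs": [],
2010 | "source": [
2011 | "# Fraction of generated SMILES that RDKit parsed into molecules.\n",
2012 | "print(f'validity: {len(mols) / len(outputs):.2%}')"
2013 | ]
2014 | },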
| "nav_menu": {}, 1908 | "number_sections": true, 1909 | "sideBar": true, 1910 | "skip_h1_title": false, 1911 | "title_cell": "Table of Contents", 1912 | "title_sidebar": "Contents", 1913 | "toc_cell": false, 1914 | "toc_position": {}, 1915 | "toc_section_display": true, 1916 | "toc_window_display": false 1917 | } 1918 | }, 1919 | "nbformat": 4, 1920 | "nbformat_minor": 4 1921 | } 1922 | --------------------------------------------------------------------------------